diff --git a/src/eviction/lfu.go b/src/eviction/lfu.go index d1c38f7..b28d779 100644 --- a/src/eviction/lfu.go +++ b/src/eviction/lfu.go @@ -3,12 +3,14 @@ package eviction import ( "container/heap" "slices" + "time" ) type EntryLFU struct { - key string // The key, matching the key in the store - count int // The number of times this key has been accessed - index int // The index of the entry in the heap + key string // The key, matching the key in the store + count int // The number of times this key has been accessed + addedTime int64 // The time this entry was added to the cache in unix milliseconds + index int // The index of the entry in the heap } type CacheLFU struct { @@ -30,6 +32,11 @@ func (cache *CacheLFU) Len() int { } func (cache *CacheLFU) Less(i, j int) bool { + // If 2 entries have the same count, return the older one + if cache.entries[i].count == cache.entries[j].count { + return cache.entries[i].addedTime > cache.entries[j].addedTime + } + // Otherwise, return the one with a lower count return cache.entries[i].count < cache.entries[j].count } @@ -42,9 +49,10 @@ func (cache *CacheLFU) Swap(i, j int) { func (cache *CacheLFU) Push(key any) { n := len(cache.entries) cache.entries = append(cache.entries, &EntryLFU{ - key: key.(string), - count: 1, - index: n, + key: key.(string), + count: 1, + addedTime: time.Now().UnixMilli(), + index: n, }) cache.keys[key.(string)] = true } diff --git a/src/modules/generic/commands.go b/src/modules/generic/commands.go index 5ecd0dc..9c40d46 100644 --- a/src/modules/generic/commands.go +++ b/src/modules/generic/commands.go @@ -35,7 +35,7 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co if !server.KeyExists(key) { res = []byte("$-1\r\n") } else { - res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(key))) + res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(ctx, key))) } } @@ -72,7 +72,7 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co // If expiresAt is set, set the key's expiry time as well if params.expireAt != nil { - server.SetKeyExpiry(key, params.expireAt.(time.Time), false) + server.SetKeyExpiry(ctx, key, params.expireAt.(time.Time), false) } return res, nil @@ -151,7 +151,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co } defer server.KeyRUnlock(key) - value := server.GetValue(key) + value := server.GetValue(ctx, key) return []byte(fmt.Sprintf("+%v\r\n", value)), nil } @@ -190,7 +190,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, _ *net.C }() for key, _ := range locks { - values[key] = fmt.Sprintf("%v", server.GetValue(key)) + values[key] = fmt.Sprintf("%v", server.GetValue(ctx, key)) } bytes := []byte(fmt.Sprintf("*%d\r\n", len(cmd[1:]))) diff --git a/src/modules/generic/commands_test.go b/src/modules/generic/commands_test.go index 41a2849..a054577 100644 --- a/src/modules/generic/commands_test.go +++ b/src/modules/generic/commands_test.go @@ -71,7 +71,7 @@ func Test_HandleSET(t *testing.T) { if rv.String() != test.expectedResponse { t.Errorf("expected response %s, got %s", test.expectedResponse, rv.String()) } - value := mockServer.GetValue(test.command[1]) + value := mockServer.GetValue(context.Background(), test.command[1]) switch value.(type) { default: t.Error("unexpected type for expectedValue") @@ -151,7 +151,7 @@ func Test_HandleMSET(t *testing.T) { t.Error("unexpected type for expectedValue") case int: ev, _ := expectedValue.(int) - value, ok := mockServer.GetValue(key).(int) + value, ok := mockServer.GetValue(context.Background(), key).(int) if !ok { t.Errorf("expected integer type for key %s, got another type", key) } @@ -160,7 +160,7 @@ func Test_HandleMSET(t *testing.T) { } case float64: ev, _ := expectedValue.(float64) - value, ok := mockServer.GetValue(key).(float64) + value, ok := mockServer.GetValue(context.Background(), key).(float64) if !ok { t.Errorf("expected float type for key %s, got another type", key) } @@ -169,7 +169,7 @@ func Test_HandleMSET(t *testing.T) { } case string: ev, _ := expectedValue.(string) - value, ok := mockServer.GetValue(key).(string) + value, ok := mockServer.GetValue(context.Background(), key).(string) if !ok { t.Errorf("expected string type for key %s, got another type", key) } diff --git a/src/modules/hash/commands.go b/src/modules/hash/commands.go index f591801..29ca636 100644 --- a/src/modules/hash/commands.go +++ b/src/modules/hash/commands.go @@ -46,7 +46,7 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -88,7 +88,7 @@ func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -139,7 +139,7 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -189,7 +189,7 @@ func handleHVALS(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -251,7 +251,7 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -346,7 +346,7 @@ func handleHLEN(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -371,7 +371,7 @@ func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -436,7 +436,7 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -493,7 +493,7 @@ func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -534,7 +534,7 @@ func handleHEXISTS(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -564,7 +564,7 @@ func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } diff --git a/src/modules/hash/commands_test.go b/src/modules/hash/commands_test.go index 989ed97..09a5fc6 100644 --- a/src/modules/hash/commands_test.go +++ b/src/modules/hash/commands_test.go @@ -125,7 +125,7 @@ func Test_HandleHSET(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - hash, ok := mockServer.GetValue(test.key).(map[string]interface{}) + hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}) if !ok { t.Errorf("value at key \"%s\" is not a hash map", test.key) } @@ -278,7 +278,7 @@ func Test_HandleHINCRBY(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - hash, ok := mockServer.GetValue(test.key).(map[string]interface{}) + hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}) if !ok { t.Errorf("value at key \"%s\" is not a hash map", test.key) } @@ -1255,7 +1255,7 @@ func Test_HandleHDEL(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - if hash, ok := mockServer.GetValue(test.key).(map[string]interface{}); ok { + if hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}); ok { for field, value := range hash { if value != test.expectedValue[field] { t.Errorf("expected value \"%+v\", got \"%+v\"", test.expectedValue[field], value) diff --git a/src/modules/list/commands.go b/src/modules/list/commands.go index bb124e5..3a0f6d4 100644 --- a/src/modules/list/commands.go +++ b/src/modules/list/commands.go @@ -29,7 +29,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, _ *net.C } defer server.KeyRUnlock(key) - if list, ok := server.GetValue(key).([]interface{}); ok { + if list, ok := server.GetValue(ctx, key).([]interface{}); ok { return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil } @@ -56,7 +56,7 @@ func handleLIndex(ctx context.Context, cmd []string, server utils.Server, conn * if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) server.KeyRUnlock(key) if !ok { @@ -93,7 +93,7 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, errors.New("LRANGE command on non-list item") } @@ -171,7 +171,7 @@ func handleLSet(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, errors.New("LSET command on non-list item") } @@ -215,7 +215,7 @@ func handleLTrim(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, errors.New("LTRIM command on non-list item") } @@ -262,7 +262,7 @@ func handleLRem(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, errors.New("LREM command on non-list item") } @@ -335,8 +335,8 @@ func handleLMove(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyUnlock(destination) - sourceList, sourceOk := server.GetValue(source).([]interface{}) - destinationList, destinationOk := server.GetValue(destination).([]interface{}) + sourceList, sourceOk := server.GetValue(ctx, source).([]interface{}) + destinationList, destinationOk := server.GetValue(ctx, destination).([]interface{}) if !sourceOk || !destinationOk { return nil, errors.New("both source and destination must be lists") @@ -399,7 +399,7 @@ func handleLPush(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyUnlock(key) - currentList := server.GetValue(key) + currentList := server.GetValue(ctx, key) l, ok := currentList.([]interface{}) if !ok { @@ -446,7 +446,7 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n defer server.KeyUnlock(key) } - currentList := server.GetValue(key) + currentList := server.GetValue(ctx, key) l, ok := currentList.([]interface{}) @@ -477,7 +477,7 @@ func handlePop(ctx context.Context, cmd []string, server utils.Server, conn *net } defer server.KeyUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0])) } diff --git a/src/modules/list/commands_test.go b/src/modules/list/commands_test.go index ae2eabf..a04dc1c 100644 --- a/src/modules/list/commands_test.go +++ b/src/modules/list/commands_test.go @@ -512,7 +512,7 @@ func Test_HandleLSET(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -670,7 +670,7 @@ func Test_HandleLTRIM(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -789,7 +789,7 @@ func Test_HandleLREM(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -999,7 +999,7 @@ func Test_HandleLMOVE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1105,7 +1105,7 @@ func Test_HandleLPUSH(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1206,7 +1206,7 @@ func Test_HandleRPUSH(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1316,7 +1316,7 @@ func Test_HandlePop(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } diff --git a/src/modules/set/commands.go b/src/modules/set/commands.go index c77531e..20923e1 100644 --- a/src/modules/set/commands.go +++ b/src/modules/set/commands.go @@ -37,7 +37,7 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -64,7 +64,7 @@ func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -88,7 +88,7 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } defer server.KeyRUnlock(keys[0]) - baseSet, ok := server.GetValue(keys[0]).(*Set) + baseSet, ok := server.GetValue(ctx, keys[0]).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", keys[0]) } @@ -114,7 +114,7 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n var sets []*Set for _, key := range cmd[2:] { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { continue } @@ -151,7 +151,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co return nil, err } defer server.KeyRUnlock(keys[1]) - baseSet, ok := server.GetValue(keys[1]).(*Set) + baseSet, ok := server.GetValue(ctx, keys[1]).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", keys[1]) } @@ -177,7 +177,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co var sets []*Set for _, key := range keys[2:] { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { continue } @@ -240,7 +240,7 @@ func handleSINTER(ctx context.Context, cmd []string, server utils.Server, conn * var sets []*Set for key, _ := range locks { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { // If the value at the key is not a set, return error return nil, fmt.Errorf("value at key %s is not a set", key) @@ -316,7 +316,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server utils.Server, co var sets []*Set for key, _ := range locks { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { // If the value at the key is not a set, return error return nil, fmt.Errorf("value at key %s is not a set", key) @@ -362,7 +362,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c var sets []*Set for key, _ := range locks { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { // If the value at the key is not a set, return error return nil, fmt.Errorf("value at key %s is not a set", key) @@ -408,7 +408,7 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, con } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -437,7 +437,7 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -480,7 +480,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -517,7 +517,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyUnlock(source) - sourceSet, ok := server.GetValue(source).(*Set) + sourceSet, ok := server.GetValue(ctx, source).(*Set) if !ok { return nil, errors.New("source is not a set") } @@ -540,7 +540,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } defer server.KeyUnlock(destination) - ds, ok := server.GetValue(destination).(*Set) + ds, ok := server.GetValue(ctx, destination).(*Set) if !ok { return nil, errors.New("destination is not a set") } @@ -578,7 +578,7 @@ func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at %s is not a set", key) } @@ -622,7 +622,7 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at %s is not a set", key) } @@ -658,7 +658,7 @@ func handleSREM(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -699,7 +699,7 @@ func handleSUNION(ctx context.Context, cmd []string, server utils.Server, conn * if !locked { continue } - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -750,7 +750,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c if !locked { continue } - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } diff --git a/src/modules/set/commant_test.go b/src/modules/set/commant_test.go index 90999cc..3081cd9 100644 --- a/src/modules/set/commant_test.go +++ b/src/modules/set/commant_test.go @@ -88,7 +88,7 @@ func Test_HandleSADD(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) if !ok { t.Errorf("expected set value at key \"%s\"", test.key) } @@ -408,7 +408,7 @@ func Test_HandleSDIFFSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -754,7 +754,7 @@ func Test_HandleSINTERSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -1126,7 +1126,7 @@ func Test_HandleSMOVE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(key) } - set, ok := mockServer.GetValue(key).(*Set) + set, ok := mockServer.GetValue(context.Background(), key).(*Set) if !ok { t.Errorf("expected set \"%s\" to be a set, got another type", key) } @@ -1223,7 +1223,7 @@ func Test_HandleSPOP(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1334,7 +1334,7 @@ func Test_HandleSRANDMEMBER(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1437,7 +1437,7 @@ func Test_HandleSREM(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1641,7 +1641,7 @@ func Test_HandleSUNIONSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } diff --git a/src/modules/sorted_set/commands.go b/src/modules/sorted_set/commands.go index 8dd2142..903f6b7 100644 --- a/src/modules/sorted_set/commands.go +++ b/src/modules/sorted_set/commands.go @@ -133,7 +133,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne return nil, err } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -180,7 +180,7 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -241,7 +241,7 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -275,7 +275,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -332,7 +332,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } defer server.KeyRUnlock(keys[0]) - baseSortedSet, ok := server.GetValue(keys[0]).(*SortedSet) + baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) } @@ -349,7 +349,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } locks[keys[i]] = locked - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -400,7 +400,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co return nil, err } defer server.KeyRUnlock(keys[0]) - baseSortedSet, ok := server.GetValue(keys[0]).(*SortedSet) + baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) } @@ -412,7 +412,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err } - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -486,7 +486,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn return nil, err } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -534,7 +534,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * return nil, err } locks[keys[i]] = true - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -600,7 +600,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c return nil, err } locks[keys[i]] = true - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -681,7 +681,7 @@ func handleZMPOP(ctx context.Context, cmd []string, server utils.Server, conn *n if _, err = server.KeyLock(ctx, keys[i]); err != nil { continue } - v, ok := server.GetValue(keys[i]).(*SortedSet) + v, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok || v.Cardinality() == 0 { server.KeyUnlock(keys[i]) continue @@ -739,7 +739,7 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at key %s is not a sorted set", key) } @@ -776,7 +776,7 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -835,7 +835,7 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -879,7 +879,7 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -923,7 +923,7 @@ func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -953,7 +953,7 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn * return nil, err } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -996,7 +996,7 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Serv } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1038,7 +1038,7 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Serve } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1095,7 +1095,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1193,7 +1193,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1330,7 +1330,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c } defer server.KeyRUnlock(source) - set, ok := server.GetValue(source).(*SortedSet) + set, ok := server.GetValue(ctx, source).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", source) } @@ -1432,7 +1432,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn * return nil, err } locks[keys[i]] = true - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -1494,7 +1494,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c return nil, err } locks[keys[i]] = true - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } diff --git a/src/modules/sorted_set/commands_test.go b/src/modules/sorted_set/commands_test.go index 45ddd93..2cef684 100644 --- a/src/modules/sorted_set/commands_test.go +++ b/src/modules/sorted_set/commands_test.go @@ -248,7 +248,7 @@ func Test_HandleZADD(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - sortedSet, ok := mockServer.GetValue(test.key).(*SortedSet) + sortedSet, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) if !ok { t.Errorf("expected the value at key \"%s\" to be a sorted set, got another type", test.key) } @@ -935,7 +935,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -1149,7 +1149,7 @@ func Test_HandleZINCRBY(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.key) } @@ -1384,7 +1384,7 @@ func Test_HandleZMPOP(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) } @@ -1555,7 +1555,7 @@ func Test_HandleZPOP(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) } @@ -1869,7 +1869,7 @@ func Test_HandleZRANDMEMBER(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -2112,7 +2112,7 @@ func Test_HandleZREM(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2218,7 +2218,7 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2378,7 +2378,7 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2509,7 +2509,7 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -3034,7 +3034,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -3693,7 +3693,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -4418,7 +4418,7 @@ func Test_HandleZUNIONSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } diff --git a/src/modules/string/commands.go b/src/modules/string/commands.go index 35a3548..ee92e6d 100644 --- a/src/modules/string/commands.go +++ b/src/modules/string/commands.go @@ -39,7 +39,7 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyUnlock(key) - str, ok := server.GetValue(key).(string) + str, ok := server.GetValue(ctx, key).(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) } @@ -100,7 +100,7 @@ func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - value, ok := server.GetValue(key).(string) + value, ok := server.GetValue(ctx, key).(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) @@ -134,7 +134,7 @@ func handleSubStr(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - value, ok := server.GetValue(key).(string) + value, ok := server.GetValue(ctx, key).(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) } diff --git a/src/modules/string/commands_test.go b/src/modules/string/commands_test.go index f06142e..c531d1b 100644 --- a/src/modules/string/commands_test.go +++ b/src/modules/string/commands_test.go @@ -138,7 +138,7 @@ func Test_HandleSetRange(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - value, ok := mockServer.GetValue(test.key).(string) + value, ok := mockServer.GetValue(context.Background(), test.key).(string) if !ok { t.Error("expected string data type, got another type") } diff --git a/src/server/keyspace.go b/src/server/keyspace.go index 4357fa4..9294514 100644 --- a/src/server/keyspace.go +++ b/src/server/keyspace.go @@ -3,20 +3,26 @@ package server import ( "context" "errors" + "fmt" "github.com/echovault/echovault/src/utils" + "log" + "math/rand" + "runtime" "slices" "strings" "sync" "time" ) -// KeyLock tries to acquire the write lock for the specified key every 5 milliseconds. +// KeyLock tries to acquire the write lock for the specified key. // If the context passed to the function finishes before the lock is acquired, an error is returned. func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) { - ticker := time.NewTicker(5 * time.Millisecond) for { select { default: + if server.keyLocks[key] == nil { + return false, fmt.Errorf("key %s not found", key) + } ok := server.keyLocks[key].TryLock() if ok { return true, nil @@ -24,21 +30,24 @@ func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) { case <-ctx.Done(): return false, context.Cause(ctx) } - <-ticker.C } } func (server *Server) KeyUnlock(key string) { - server.keyLocks[key].Unlock() + if server.KeyExists(key) { + server.keyLocks[key].Unlock() + } } -// KeyRLock tries to acquire the read lock for the specified key every few milliseconds. +// KeyRLock tries to acquire the read lock for the specified key. // If the context passed to the function finishes before the lock is acquired, an error is returned. func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) { - ticker := time.NewTicker(5 * time.Millisecond) for { select { default: + if server.keyLocks[key] == nil { + return false, fmt.Errorf("key %s not found", key) + } ok := server.keyLocks[key].TryRLock() if ok { return true, nil @@ -46,12 +55,13 @@ func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) { case <-ctx.Done(): return false, context.Cause(ctx) } - <-ticker.C } } func (server *Server) KeyRUnlock(key string) { - server.keyLocks[key].RUnlock() + if server.KeyExists(key) { + server.keyLocks[key].RUnlock() + } } func (server *Server) KeyExists(key string) bool { @@ -61,7 +71,7 @@ func (server *Server) KeyExists(key string) bool { // CreateKeyAndLock creates a new key lock and immediately locks it if the key does not exist. // If the key exists, the existing key is locked. func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, error) { - if utils.IsMaxMemoryExceeded(server.Config) && server.Config.EvictionPolicy == utils.NoEviction { + if utils.IsMaxMemoryExceeded(server.Config.MaxMemory) && server.Config.EvictionPolicy == utils.NoEviction { return false, errors.New("max memory reached, key not created") } @@ -80,8 +90,11 @@ func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, e // GetValue retrieves the current value at the specified key. // The key must be read-locked before calling this function. -func (server *Server) GetValue(key string) interface{} { - server.updateKeyInCache(key) +func (server *Server) GetValue(ctx context.Context, key string) interface{} { + err := server.updateKeyInCache(ctx, key) + if err != nil { + log.Printf("GetValue error: %+v\n", err) + } return server.store[key] } @@ -90,14 +103,17 @@ func (server *Server) GetValue(key string) interface{} { // in the snapshot engine. // This count triggers a snapshot when the threshold is reached. // The key must be locked prior to calling this function. -func (server *Server) SetValue(_ context.Context, key string, value interface{}) error { - if utils.IsMaxMemoryExceeded(server.Config) && server.Config.EvictionPolicy == utils.NoEviction { +func (server *Server) SetValue(ctx context.Context, key string, value interface{}) error { + if utils.IsMaxMemoryExceeded(server.Config.MaxMemory) && server.Config.EvictionPolicy == utils.NoEviction { return errors.New("max memory reached, key value not set") } server.store[key] = value - server.updateKeyInCache(key) + err := server.updateKeyInCache(ctx, key) + if err != nil { + log.Printf("SetValue error: %+v\n", err) + } if !server.IsInCluster() { server.SnapshotEngine.IncrementChangeCount() @@ -112,10 +128,13 @@ func (server *Server) SetValue(_ context.Context, key string, value interface{}) // The touch parameter determines whether to update the keys access count on lfu eviction policy, // or the access time on lru eviction policy. // The key must be locked prior to calling this function. -func (server *Server) SetKeyExpiry(key string, expire time.Time, touch bool) { +func (server *Server) SetKeyExpiry(ctx context.Context, key string, expire time.Time, touch bool) { server.keyExpiry[key] = expire if touch { - server.updateKeyInCache(key) + err := server.updateKeyInCache(ctx, key) + if err != nil { + log.Printf("SetKeyExpiry error: %+v\n", err) + } } } @@ -153,7 +172,11 @@ func (server *Server) GetState() map[string]interface{} { // updateKeyInCache updates either the key access count or the most recent access time in the cache // depending on whether an LFU or LRU strategy was used. -func (server *Server) updateKeyInCache(key string) { +func (server *Server) updateKeyInCache(ctx context.Context, key string) error { + // If max memory is 0, there's no max so no need to update caches + if server.Config.MaxMemory == 0 { + return nil + } switch strings.ToLower(server.Config.EvictionPolicy) { case utils.AllKeysLFU: server.lfuCache.Update(key) @@ -168,5 +191,150 @@ func (server *Server) updateKeyInCache(key string) { server.lruCache.Update(key) } } - // TODO: Check if memory usage is above max-memory. If it is, pop items from the cache until we get under the limit. + if err := server.adjustMemoryUsage(ctx); err != nil { + return fmt.Errorf("updateKeyInCache: %+v", err) + } + return nil +} + +func (server *Server) adjustMemoryUsage(ctx context.Context) error { + // If max memory is 0, there's no need to adjust memory usage. + if server.Config.MaxMemory == 0 { + return nil + } + // Check if memory usage is above max-memory. + // If it is, pop items from the cache until we get under the limit. + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + // If we're using less memory than the max-memory, there's no need to evict. + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + // Force a garbage collection first before we start evicting key. + runtime.GC() + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + // We've done a GC, but we're still at or above the max memory limit. + // Start a loop that evicts keys until either the heap is empty or + // we're below the max memory limit. + for { + switch { + case slices.Contains([]string{utils.AllKeysLFU, utils.VolatileLFU}, strings.ToLower(server.Config.EvictionPolicy)): + // Remove keys from LFU cache until we're below the max memory limit or + // until the LFU cache is empty. + for { + // Return if cache is empty + if server.lfuCache.Len() == 0 { + return fmt.Errorf("adjsutMemoryUsage -> LFU cache empty") + } + key := server.lfuCache.Pop().(string) + if _, err := server.KeyLock(ctx, key); err != nil { + return fmt.Errorf("adjustMemoryUsage -> LFU cache eviction: %+v", err) + } + // Delete the keys + delete(server.store, key) + delete(server.keyExpiry, key) + delete(server.keyLocks, key) + // Run garbage collection + runtime.GC() + // Return if we're below max memory + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + } + case slices.Contains([]string{utils.AllKeysLRU, utils.VolatileLRU}, strings.ToLower(server.Config.EvictionPolicy)): + // Remove keys from th LRU cache until we're below the max memory limit or + // until the LRU cache is empty. + for { + // Return if cache is empty + if server.lruCache.Len() == 0 { + return fmt.Errorf("adjsutMemoryUsage -> LRU cache empty") + } + key := server.lruCache.Pop().(string) + if _, err := server.KeyLock(ctx, key); err != nil { + return fmt.Errorf("adjustMemoryUsage -> LRU cache eviction: %+v", err) + } + // Delete the keys + delete(server.store, key) + delete(server.keyExpiry, key) + delete(server.keyLocks, key) + // Run garbage collection + runtime.GC() + // Return if we're below max memory + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + } + case slices.Contains([]string{utils.AllKeysRandom}, strings.ToLower(server.Config.EvictionPolicy)): + // Remove random keys until we're below the max memory limit + // or there are no more keys remaining. + for { + // If there are no keys, return error + if len(server.keyLocks) == 0 { + err := errors.New("no keys to evict") + return fmt.Errorf("adjustMemoryUsage -> all keys random: %+v", err) + } + // Get random key + idx := rand.Intn(len(server.keyLocks)) + for key, _ := range server.keyLocks { + if idx == 0 { + // Lock the key + if _, err := server.KeyLock(ctx, key); err != nil { + return fmt.Errorf("adjustMemoryUsage -> all keys random: %+v", err) + } + // Delete the key + delete(server.keyLocks, key) + delete(server.store, key) + delete(server.keyExpiry, key) + // Run garbage collection + runtime.GC() + // Return if we're below max memory + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + } + idx-- + } + } + case slices.Contains([]string{utils.VolatileRandom}, strings.ToLower(server.Config.EvictionPolicy)): + // Remove random keys with expiry time until we're below the max memory limit + // or there are no more keys with expiry time. + for { + // If there are no volatile keys, return error + if len(server.keyExpiry) == 0 { + err := errors.New("no volatile keys to evict") + return fmt.Errorf("adjustMemoryUsage -> volatile keys random: %+v", err) + } + // Get random volatile key + idx := rand.Intn(len(server.keyExpiry)) + for key, _ := range server.keyExpiry { + if idx == 0 { + // Lock the key + if _, err := server.KeyLock(ctx, key); err != nil { + return fmt.Errorf("adjustMemoryUsage -> volatile keys random: %+v", err) + } + // Delete the key + delete(server.keyLocks, key) + delete(server.store, key) + delete(server.keyExpiry, key) + // Run garbage collection + runtime.GC() + // Return if we're below max memory + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + } + idx-- + } + } + default: + return nil + } + } } diff --git a/src/server/server.go b/src/server/server.go index 9648b7c..92dbecb 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -43,13 +43,13 @@ type Server struct { ACL utils.ACL PubSub utils.PubSub - SnapshotInProgress atomic.Bool - RewriteAOFInProgress atomic.Bool - StateCopyInProgress atomic.Bool - StateMutationInProgress atomic.Bool - LatestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds - SnapshotEngine *snapshot.Engine - AOFEngine *aof.Engine + SnapshotInProgress atomic.Bool // Atomic boolean that's true when actively taking a snapshot. + RewriteAOFInProgress atomic.Bool // Atomic boolean that's true when actively rewriting AOF file is in progress. + StateCopyInProgress atomic.Bool // Atomic boolean that's true when actively copying state for snapshotting or preamble generation. + StateMutationInProgress atomic.Bool // Atomic boolean that is set to true when state mutation is in progress. + LatestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds + SnapshotEngine *snapshot.Engine // Snapshot engine for standalone mode + AOFEngine *aof.Engine // AOF engine for standalone mode } type Opts struct { @@ -136,7 +136,8 @@ func NewServer(opts Opts) *Server { server.lfuCache = eviction.NewCacheLFU() server.lruCache = eviction.NewCacheLRU() - // TODO: Start goroutine that continuously reads the mem stats before triggering purge once max-memory is reached + // TODO: If eviction policy is volatile-ttl, start goroutine that continuously reads the mem stats + // TODO: before triggering purge once max-memory is reached return server } diff --git a/src/utils/config.go b/src/utils/config.go index 98156de..0bdeec2 100644 --- a/src/utils/config.go +++ b/src/utils/config.go @@ -80,7 +80,8 @@ The options are 'always' for syncing on each command, 'everysec' to sync every s var maxMemory uint64 = 0 flag.Func("max-memory", `Upper memory limit before triggering eviction. -Supported units (kb, mb, gb, tb, pb). There is no limit by default.`, func(memory string) error { +Supported units (kb, mb, gb, tb, pb). When 0 is passed, there will be no memory limit. +There is no limit by default.`, func(memory string) error { b, err := ParseMemory(memory) if err != nil { return err diff --git a/src/utils/types.go b/src/utils/types.go index 2b42cbe..3679a4f 100644 --- a/src/utils/types.go +++ b/src/utils/types.go @@ -13,9 +13,9 @@ type Server interface { KeyRUnlock(key string) KeyExists(key string) bool CreateKeyAndLock(ctx context.Context, key string) (bool, error) - GetValue(key string) interface{} + GetValue(ctx context.Context, key string) interface{} SetValue(ctx context.Context, key string, value interface{}) error - SetKeyExpiry(key string, expire time.Time, touch bool) + SetKeyExpiry(ctx context.Context, key string, expire time.Time, touch bool) RemoveKeyExpiry(key string) GetState() map[string]interface{} GetAllCommands(ctx context.Context) []Command diff --git a/src/utils/utils.go b/src/utils/utils.go index 28ff13c..8fb237d 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -163,12 +163,16 @@ func ParseMemory(memory string) (uint64, error) { } // IsMaxMemoryExceeded checks whether we have exceeded the current maximum memory limit -func IsMaxMemoryExceeded(config Config) bool { +func IsMaxMemoryExceeded(maxMemory uint64) bool { + if maxMemory == 0 { + return false + } + var memStats runtime.MemStats runtime.ReadMemStats(&memStats) // If we're currently using less than the configured max memory, return false - if memStats.HeapInuse < config.MaxMemory { + if memStats.HeapInuse < maxMemory { return false } @@ -176,7 +180,8 @@ func IsMaxMemoryExceeded(config Config) bool { // This measure is to prevent deleting keys that may be important when some memory can be reclaimed // by just collecting garbage. runtime.GC() + runtime.ReadMemStats(&memStats) // Return true when whe are above or equal to max memory. - return memStats.HeapInuse >= config.MaxMemory + return memStats.HeapInuse >= maxMemory }