diff --git a/cache.go b/cache.go index 61ad1c0..d8a929f 100644 --- a/cache.go +++ b/cache.go @@ -438,6 +438,7 @@ func (c *Cache[T]) gc() int { c.onDelete(item) } dropped += 1 + item.node = nil item.promotions = -2 } node = prev diff --git a/cache_test.go b/cache_test.go index 6209c8b..f8cc06c 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1,6 +1,7 @@ package ccache import ( + "math/rand" "sort" "strconv" "sync/atomic" @@ -313,6 +314,29 @@ func Test_CacheForEachFunc(t *testing.T) { assert.DoesNotContain(t, forEachKeys(cache), "stop") } +func Test_CachePrune(t *testing.T) { + maxSize := int64(500) + cache := New(Configure[string]().MaxSize(maxSize).ItemsToPrune(50)) + epoch := 0 + for i := 0; i < 10000; i++ { + epoch += 1 + expired := make([]string, 0) + for i := 0; i < 50; i += 1 { + key := strconv.FormatInt(rand.Int63n(maxSize*20), 10) + item := cache.Get(key) + if item == nil || item.TTL() > 1*time.Minute { + expired = append(expired, key) + } + } + for _, key := range expired { + cache.Set(key, key, 5*time.Minute) + } + if epoch%500 == 0 { + assert.True(t, cache.GetSize() < 500) + } + } +} + type SizedItem struct { id int s int64 diff --git a/layeredcache.go b/layeredcache.go index f98b291..b6798dd 100644 --- a/layeredcache.go +++ b/layeredcache.go @@ -355,6 +355,7 @@ func (c *LayeredCache[T]) gc() int { if c.onDelete != nil { c.onDelete(item) } + item.node = nil item.promotions = -2 dropped += 1 } diff --git a/layeredcache_test.go b/layeredcache_test.go index c993a47..adf5483 100644 --- a/layeredcache_test.go +++ b/layeredcache_test.go @@ -1,6 +1,7 @@ package ccache import ( + "math/rand" "sort" "strconv" "sync/atomic" @@ -372,6 +373,29 @@ func Test_LayeredCache_EachFunc(t *testing.T) { assert.DoesNotContain(t, forEachKeysLayered[int](cache, "1"), "stop") } +func Test_LayeredCachePrune(t *testing.T) { + maxSize := int64(500) + cache := Layered(Configure[string]().MaxSize(maxSize).ItemsToPrune(50)) + epoch := 0 + for i := 0; i < 10000; i++ { + epoch += 1 + expired := make([]string, 0) + for i := 0; i < 50; i += 1 { + key := strconv.FormatInt(rand.Int63n(maxSize*20), 10) + item := cache.Get(key, key) + if item == nil || item.TTL() > 1*time.Minute { + expired = append(expired, key) + } + } + for _, key := range expired { + cache.Set(key, key, key, 5*time.Minute) + } + if epoch%500 == 0 { + assert.True(t, cache.GetSize() < 500) + } + } +} + func newLayered[T any]() *LayeredCache[T] { c := Layered[T](Configure[T]()) c.Clear()