From 75b43b2ac485a272d03345c02834ed42a70c66e4 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Wed, 25 Jun 2025 14:13:21 +0300 Subject: [PATCH] add context support to more storages --- badger/badger.go | 25 ++++++++ bbolt/bbolt.go | 26 ++++++-- memcache/memcache.go | 26 ++++++-- memory/memory.go | 38 +++++++---- mockstorage/mockstorage.go | 80 +++++++++++++++-------- mockstorage/mockstorage_test.go | 109 ++++++++++++++++++++++++++++++++ pebble/pebble.go | 45 +++++++++---- ristretto/ristretto.go | 26 +++++++- 8 files changed, 312 insertions(+), 63 deletions(-) diff --git a/badger/badger.go b/badger/badger.go index 245b63a1..ec0f800a 100644 --- a/badger/badger.go +++ b/badger/badger.go @@ -1,6 +1,7 @@ package badger import ( + "context" "time" "github.com/dgraph-io/badger/v3" @@ -72,6 +73,12 @@ func (s *Storage) Get(key string) ([]byte, error) { return data, err } +// GetWithContext gets value by key. +// Note: This method is not used in the current implementation, but is included to satisfy the Storage interface. +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + return s.Get(key) +} + // Set key with value func (s *Storage) Set(key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That @@ -88,6 +95,12 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { }) } +// SetWithContext sets key with value. +// Note: This method is not used in the current implementation, but is included to satisfy the Storage interface. +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + return s.Set(key, val, exp) +} + // Delete key by key func (s *Storage) Delete(key string) error { // Ain't Nobody Got Time For That @@ -99,11 +112,23 @@ func (s *Storage) Delete(key string) error { }) } +// DeleteWithContext deletes key by key. +// Note: This method is not used in the current implementation, but is included to satisfy the Storage interface. +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.Delete(key) +} + // Reset all keys func (s *Storage) Reset() error { return s.db.DropAll() } +// ResetWithContext resets all keys. +// Note: This method is not used in the current implementation, but is included to satisfy the Storage interface. +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.db.DropAll() +} + // Close the memory storage func (s *Storage) Close() error { s.done <- struct{}{} diff --git a/bbolt/bbolt.go b/bbolt/bbolt.go index 85dc97b0..3522bbb4 100644 --- a/bbolt/bbolt.go +++ b/bbolt/bbolt.go @@ -1,6 +1,7 @@ package bbolt import ( + "context" "time" "github.com/gofiber/utils/v2" @@ -62,6 +63,11 @@ func (s *Storage) Get(key string) ([]byte, error) { return value, err } +// GetWithContext gets value by key (dummy context support) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + return s.Get(key) +} + // Set key with value func (s *Storage) Set(key string, value []byte, exp time.Duration) error { if len(key) <= 0 || len(value) <= 0 { @@ -70,11 +76,15 @@ func (s *Storage) Set(key string, value []byte, exp time.Duration) error { return s.conn.Update(func(tx *bbolt.Tx) error { b := tx.Bucket(utils.UnsafeBytes(s.bucket)) - return b.Put(utils.UnsafeBytes(key), value) }) } +// SetWithContext sets key with value (dummy context support) +func (s *Storage) SetWithContext(ctx context.Context, key string, value []byte, exp time.Duration) error { + return s.Set(key, value, exp) +} + // Delete entry by key func (s *Storage) Delete(key string) error { if len(key) <= 0 { @@ -83,28 +93,36 @@ func (s *Storage) Delete(key string) error { return s.conn.Update(func(tx *bbolt.Tx) error { b := tx.Bucket(utils.UnsafeBytes(s.bucket)) - return b.Delete(utils.UnsafeBytes(key)) }) } +// DeleteWithContext deletes key by key (dummy context support) +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.Delete(key) +} + // Reset all entries func (s *Storage) Reset() error { return s.conn.Update(func(tx *bbolt.Tx) error { b := tx.Bucket(utils.UnsafeBytes(s.bucket)) - return b.ForEach(func(k, _ []byte) error { return b.Delete(k) }) }) } +// ResetWithContext resets all entries (dummy context support) +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.Reset() +} + // Close the database func (s *Storage) Close() error { return s.conn.Close() } -// Return database client +// Conn returns the database client func (s *Storage) Conn() *bbolt.DB { return s.conn } diff --git a/memcache/memcache.go b/memcache/memcache.go index 2afb2ea6..139da935 100644 --- a/memcache/memcache.go +++ b/memcache/memcache.go @@ -1,6 +1,7 @@ package memcache import ( + "context" "strings" "sync" "time" @@ -68,10 +69,13 @@ func (s *Storage) Get(key string) ([]byte, error) { return item.Value, nil } -// Set key with value +// GetWithContext gets value by key (dummy context support) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + return s.Get(key) +} + // Set key with value func (s *Storage) Set(key string, val []byte, exp time.Duration) error { - // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil } @@ -87,20 +91,34 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return err } +// SetWithContext sets key with value (dummy context support) +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + return s.Set(key, val, exp) +} + // Delete key by key func (s *Storage) Delete(key string) error { - // Ain't Nobody Got Time For That if len(key) <= 0 { return nil } return s.db.Delete(key) } +// DeleteWithContext deletes key by key (dummy context support) +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.Delete(key) +} + // Reset all keys func (s *Storage) Reset() error { return s.db.DeleteAll() } +// ResetWithContext resets all keys (dummy context support) +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.Reset() +} + // Close the database func (s *Storage) Close() error { return nil @@ -111,7 +129,7 @@ func (s *Storage) acquireItem() *mc.Item { return s.items.Get().(*mc.Item) } -// Release item from pool +// Release item back to pool func (s *Storage) releaseItem(item *mc.Item) { if item != nil { item.Key = "" diff --git a/memory/memory.go b/memory/memory.go index cf35923c..900cae59 100644 --- a/memory/memory.go +++ b/memory/memory.go @@ -1,6 +1,7 @@ package memory import ( + "context" "sync" "sync/atomic" "time" @@ -17,9 +18,8 @@ type Storage struct { } type entry struct { - data []byte - // max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000 - expiry uint32 + data []byte + expiry uint32 // max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000 } // New creates a new memory storage @@ -49,16 +49,20 @@ func (s *Storage) Get(key string) ([]byte, error) { s.mux.RLock() v, ok := s.db[key] s.mux.RUnlock() - if !ok || v.expiry != 0 && v.expiry <= atomic.LoadUint32(&internal.Timestamp) { + if !ok || (v.expiry != 0 && v.expiry <= atomic.LoadUint32(&internal.Timestamp)) { return nil, nil } return v.data, nil } +// GetWithContext gets value by key (dummy context support) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + return s.Get(key) +} + // Set key with value func (s *Storage) Set(key string, val []byte, exp time.Duration) error { - // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil } @@ -75,9 +79,13 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return nil } +// SetWithContext sets value by key (dummy context support) +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + return s.Set(key, val, exp) +} + // Delete key by key func (s *Storage) Delete(key string) error { - // Ain't Nobody Got Time For That if len(key) <= 0 { return nil } @@ -87,6 +95,11 @@ func (s *Storage) Delete(key string) error { return nil } +// DeleteWithContext deletes key (dummy context support) +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.Delete(key) +} + // Reset all keys func (s *Storage) Reset() error { ndb := make(map[string]entry) @@ -96,6 +109,11 @@ func (s *Storage) Reset() error { return nil } +// ResetWithContext resets all keys (dummy context support) +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.Reset() +} + // Close the memory storage func (s *Storage) Close() error { s.done <- struct{}{} @@ -122,8 +140,6 @@ func (s *Storage) gc() { } s.mux.RUnlock() s.mux.Lock() - // Double-checked locking. - // We might have replaced the item in the meantime. for i := range expired { v := s.db[expired[i]] if v.expiry != 0 && v.expiry <= ts { @@ -135,14 +151,14 @@ func (s *Storage) gc() { } } -// Return database client +// Conn returns database client func (s *Storage) Conn() map[string]entry { s.mux.RLock() defer s.mux.RUnlock() return s.db } -// Return all the keys +// Keys returns all the keys func (s *Storage) Keys() ([][]byte, error) { s.mux.RLock() defer s.mux.RUnlock() @@ -154,13 +170,11 @@ func (s *Storage) Keys() ([][]byte, error) { ts := atomic.LoadUint32(&internal.Timestamp) keys := make([][]byte, 0, len(s.db)) for key, v := range s.db { - // Filter out the expired keys if v.expiry == 0 || v.expiry > ts { keys = append(keys, []byte(key)) } } - // Double check if no valid keys were found if len(keys) == 0 { return nil, nil } diff --git a/mockstorage/mockstorage.go b/mockstorage/mockstorage.go index 000757d8..36f59c89 100644 --- a/mockstorage/mockstorage.go +++ b/mockstorage/mockstorage.go @@ -1,6 +1,7 @@ package mockstorage import ( + "context" "errors" "sync" "time" @@ -26,31 +27,26 @@ type Entry struct { // CustomFuncs allows injecting custom behaviors for testing. type CustomFuncs struct { - GetFunc func(key string) ([]byte, error) - SetFunc func(key string, val []byte, exp time.Duration) error - DeleteFunc func(key string) error - ResetFunc func() error - CloseFunc func() error - ConnFunc func() map[string]Entry - KeysFunc func() ([][]byte, error) + GetFunc func(key string) ([]byte, error) + GetWithContext func(ctx context.Context, key string) ([]byte, error) + SetFunc func(key string, val []byte, exp time.Duration) error + SetWithContext func(ctx context.Context, key string, val []byte, exp time.Duration) error + DeleteFunc func(key string) error + DeleteWithContext func(ctx context.Context, key string) error + ResetFunc func() error + ResetWithContext func(ctx context.Context) error + CloseFunc func() error + ConnFunc func() map[string]Entry + KeysFunc func() ([][]byte, error) } // New creates a new mock storage with optional configuration. func New(config ...Config) *Storage { s := &Storage{ - data: make(map[string]Entry), - custom: &CustomFuncs{ - GetFunc: nil, - SetFunc: nil, - DeleteFunc: nil, - ResetFunc: nil, - CloseFunc: nil, - ConnFunc: nil, - KeysFunc: nil, - }, + data: make(map[string]Entry), + custom: &CustomFuncs{}, // default no-op } - // If a config is provided and it has CustomFuncs, use them if len(config) > 0 && config[0].CustomFuncs != nil { s.custom = config[0].CustomFuncs } @@ -78,7 +74,15 @@ func (s *Storage) Get(key string) ([]byte, error) { return e.Value, nil } -// Set sets the value for a given key with an expiration time. +// GetWithContext retrieves value by key using a context (functional or fallback) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + if s.custom.GetWithContext != nil { + return s.custom.GetWithContext(ctx, key) + } + return s.Get(key) +} + +// Set sets the value for a given key with expiration. func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if s.custom.SetFunc != nil { return s.custom.SetFunc(key, val, exp) @@ -96,6 +100,14 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return nil } +// SetWithContext sets value using context. +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + if s.custom.SetWithContext != nil { + return s.custom.SetWithContext(ctx, key, val, exp) + } + return s.Set(key, val, exp) +} + // Delete removes a key from the storage. func (s *Storage) Delete(key string) error { if s.custom.DeleteFunc != nil { @@ -104,12 +116,19 @@ func (s *Storage) Delete(key string) error { s.mu.Lock() defer s.mu.Unlock() - delete(s.data, key) return nil } -// Reset clears all keys from the storage. +// DeleteWithContext deletes key using context. +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + if s.custom.DeleteWithContext != nil { + return s.custom.DeleteWithContext(ctx, key) + } + return s.Delete(key) +} + +// Reset clears all keys. func (s *Storage) Reset() error { if s.custom.ResetFunc != nil { return s.custom.ResetFunc() @@ -117,22 +136,27 @@ func (s *Storage) Reset() error { s.mu.Lock() defer s.mu.Unlock() - s.data = make(map[string]Entry) return nil } -// Close closes the storage (no-op for mock). +// ResetWithContext resets storage using context. +func (s *Storage) ResetWithContext(ctx context.Context) error { + if s.custom.ResetWithContext != nil { + return s.custom.ResetWithContext(ctx) + } + return s.Reset() +} + +// Close closes the mock storage (no-op). func (s *Storage) Close() error { if s.custom.CloseFunc != nil { return s.custom.CloseFunc() } - - // No resources to clean up in mock return nil } -// Conn returns the internal data map (for testing purposes). +// Conn returns internal map. func (s *Storage) Conn() map[string]Entry { if s.custom.ConnFunc != nil { return s.custom.ConnFunc() @@ -148,7 +172,7 @@ func (s *Storage) Conn() map[string]Entry { return copyData } -// Keys returns all keys in the storage. +// Keys returns all keys. func (s *Storage) Keys() ([][]byte, error) { if s.custom.KeysFunc != nil { return s.custom.KeysFunc() @@ -164,7 +188,7 @@ func (s *Storage) Keys() ([][]byte, error) { return keys, nil } -// SetCustomFuncs allows setting custom function implementations. +// SetCustomFuncs allows runtime injection of function implementations. func (s *Storage) SetCustomFuncs(custom *CustomFuncs) { s.custom = custom } diff --git a/mockstorage/mockstorage_test.go b/mockstorage/mockstorage_test.go index f5cbbf9d..c605063f 100644 --- a/mockstorage/mockstorage_test.go +++ b/mockstorage/mockstorage_test.go @@ -2,6 +2,7 @@ package mockstorage import ( "bytes" + "context" "errors" "testing" "time" @@ -274,3 +275,111 @@ func TestStorageConnAndKeys(t *testing.T) { t.Errorf("Keys() = %v, want %v", keys, [][]byte{[]byte("key1")}) } } + +func TestGetWithContext(t *testing.T) { + store := New() + + // fallback to Get + _ = store.Set("key1", []byte("val1"), 0) + val, err := store.GetWithContext(context.Background(), "key1") + if err != nil || !bytes.Equal(val, []byte("val1")) { + t.Errorf("GetWithContext fallback failed: got %v, err %v", val, err) + } + + // custom override + store.SetCustomFuncs(&CustomFuncs{ + GetWithContext: func(ctx context.Context, key string) ([]byte, error) { + if key == "override" { + return []byte("ctx-value"), nil + } + return nil, errors.New("not found") + }, + }) + val, err = store.GetWithContext(context.TODO(), "override") + if err != nil || !bytes.Equal(val, []byte("ctx-value")) { + t.Errorf("GetWithContext custom failed: got %v, err %v", val, err) + } +} + +func TestSetWithContext(t *testing.T) { + store := New() + + // fallback to Set + err := store.SetWithContext(context.TODO(), "key2", []byte("val2"), 0) + if err != nil { + t.Errorf("SetWithContext fallback failed: %v", err) + } + val, _ := store.Get("key2") + if !bytes.Equal(val, []byte("val2")) { + t.Errorf("SetWithContext fallback mismatch: got %v", val) + } + + // custom override + store.SetCustomFuncs(&CustomFuncs{ + SetWithContext: func(ctx context.Context, key string, val []byte, exp time.Duration) error { + if key == "readonly" { + return errors.New("forbidden") + } + return nil + }, + }) + err = store.SetWithContext(context.TODO(), "readonly", []byte("fail"), 0) + if err == nil || err.Error() != "forbidden" { + t.Errorf("SetWithContext custom override failed: err=%v", err) + } +} + +func TestDeleteWithContext(t *testing.T) { + store := New() + + // fallback to Delete + _ = store.Set("key3", []byte("val3"), 0) + err := store.DeleteWithContext(context.TODO(), "key3") + if err != nil { + t.Errorf("DeleteWithContext fallback failed: %v", err) + } + val, err := store.Get("key3") + if err == nil { + t.Errorf("expected deletion, but got value: %v", val) + } + + // custom override + store.SetCustomFuncs(&CustomFuncs{ + DeleteWithContext: func(ctx context.Context, key string) error { + if key == "undeletable" { + return errors.New("blocked") + } + return nil + }, + }) + err = store.DeleteWithContext(context.TODO(), "undeletable") + if err == nil || err.Error() != "blocked" { + t.Errorf("DeleteWithContext custom override failed: err=%v", err) + } +} + +func TestResetWithContext(t *testing.T) { + store := New() + + // fallback to Reset + _ = store.Set("key4", []byte("val4"), 0) + err := store.ResetWithContext(context.TODO()) + if err != nil { + t.Errorf("ResetWithContext fallback failed: %v", err) + } + val, err := store.Get("key4") + if err == nil { + t.Errorf("expected reset to remove key, but got value: %v", val) + } + + // custom override + store.SetCustomFuncs(&CustomFuncs{ + ResetWithContext: func(ctx context.Context) error { + return errors.New("custom reset error") + }, + }) + err = store.ResetWithContext(context.Background()) + if err == nil || err.Error() != "custom reset error" { + t.Errorf("ResetWithContext custom override failed: err=%v", err) + } +} diff --git a/pebble/pebble.go b/pebble/pebble.go index e0573199..c1c6b536 100644 --- a/pebble/pebble.go +++ b/pebble/pebble.go @@ -1,6 +1,7 @@ package pebble import ( + "context" "encoding/json" "errors" "log" @@ -39,8 +40,7 @@ func New(config ...Config) *Storage { } } -// // Implement the logic to retrieve the value for the given key from the storage provider -// // Return nil, nil if the key does not exist +// Get retrieves the value by key. func (s *Storage) Get(key string) ([]byte, error) { if len(key) <= 0 { return nil, nil @@ -60,7 +60,6 @@ func (s *Storage) Get(key string) ([]byte, error) { var cache CacheType err = json.Unmarshal(data, &cache) - if err != nil { return nil, nil } @@ -71,18 +70,26 @@ func (s *Storage) Get(key string) ([]byte, error) { err = s.db.Delete([]byte(key), nil) return nil, err } + return cache.Data, nil } -// // Implement the logic to store the given value for the given key in the storage provider -// // Use the provided expiration value (0 means no expiration) -// // Ignore empty key or value without returning an error +// GetWithContext retrieves value by key (dummy context support) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + return s.Get(key) +} + +// Set stores the given value with optional expiration func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if len(key) <= 0 || len(val) <= 0 { return nil } - cache := CacheType{Data: []byte(val), Created: time.Now().Unix(), Expires: 0} + cache := CacheType{ + Data: val, + Created: time.Now().Unix(), + Expires: 0, + } if exp > 0 { cache.Expires = cache.Created + int64(exp.Seconds()) @@ -95,8 +102,12 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return s.db.Set([]byte(key), jsonString, s.writeOptions) } -// // Implement the logic to delete the value for the given key from the storage provider -// // Return no error if the key does not exist in the storage +// SetWithContext sets value by key (dummy context support) +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + return s.Set(key, val, exp) +} + +// Delete removes a value by key func (s *Storage) Delete(key string) error { if len(key) <= 0 { return nil @@ -104,26 +115,36 @@ func (s *Storage) Delete(key string) error { return s.db.Delete([]byte(key), s.writeOptions) } +// DeleteWithContext deletes key (dummy context support) +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.Delete(key) +} + +// Reset flushes the DB func (s *Storage) Reset() error { return s.db.Flush() } +// ResetWithContext resets storage (dummy context support) +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.Reset() +} + +// Close closes the database func (s *Storage) Close() error { return s.db.Close() } -// // Return database client +// Conn returns the database client func (s *Storage) Conn() *pebble.DB { return s.db } func isValid(fp string) bool { - // Check if file already exists if _, err := os.Stat(fp); err == nil { return true } - // Attempt to create it var d []byte err := os.WriteFile(fp, d, 0o600) if err != nil { diff --git a/ristretto/ristretto.go b/ristretto/ristretto.go index e7cd6873..6dd29d8a 100644 --- a/ristretto/ristretto.go +++ b/ristretto/ristretto.go @@ -1,6 +1,7 @@ package ristretto import ( + "context" "time" "github.com/dgraph-io/ristretto" @@ -52,6 +53,11 @@ func (s *Storage) Get(key string) ([]byte, error) { return buf, nil } +// GetWithContext gets the value by key (dummy context support) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + return s.Get(key) +} + // Set stores the given value for the given key along // with an expiration value, time.Time{} means no expiration. // Empty key or value will be ignored without an error. @@ -66,8 +72,12 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return nil } +// SetWithContext sets value by key (dummy context support) +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + return s.Set(key, val, exp) +} + // Delete deletes the value for the given key. -// It returns no error if the storage does not contain the key, func (s *Storage) Delete(key string) error { if len(key) <= 0 { return nil @@ -76,12 +86,22 @@ func (s *Storage) Delete(key string) error { return nil } -// Reset resets the storage and delete all keys. +// DeleteWithContext deletes key (dummy context support) +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.Delete(key) +} + +// Reset resets the storage and deletes all keys. func (s *Storage) Reset() error { s.cache.Clear() return nil } +// ResetWithContext resets storage (dummy context support) +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.Reset() +} + // Close closes the storage and will stop any running garbage // collectors and open connections. func (s *Storage) Close() error { @@ -89,7 +109,7 @@ func (s *Storage) Close() error { return nil } -// Return database client +// Conn returns the database client func (s *Storage) Conn() *ristretto.Cache { return s.cache }