diff --git a/aerospike/aerospike.go b/aerospike/aerospike.go index 6101a223..015d3875 100644 --- a/aerospike/aerospike.go +++ b/aerospike/aerospike.go @@ -1,6 +1,7 @@ package aerospike import ( + "context" "log" "time" @@ -216,6 +217,11 @@ func (s *Storage) Get(key string) ([]byte, error) { return data, nil } +// GetWithContext gets value by key (dummy context support) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + return s.Get(key) +} + // Set key with value func (s *Storage) Set(key string, val []byte, exp time.Duration) error { k, err := aerospike.NewKey(s.namespace, s.setName, key) @@ -242,6 +248,11 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return s.client.Put(writePolicy, k, bins) } +// SetWithContext sets value by key (dummy context support) +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + return s.Set(key, val, exp) +} + // Delete key func (s *Storage) Delete(key string) error { k, err := aerospike.NewKey(s.namespace, s.setName, key) @@ -253,6 +264,11 @@ func (s *Storage) Delete(key string) error { return err } +// DeleteWithContext deletes key (dummy context support) +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.Delete(key) +} + // Reset all keys func (s *Storage) Reset() error { // Use ScanAll which returns a Recordset @@ -293,6 +309,11 @@ func (s *Storage) Reset() error { return nil } +// ResetWithContext resets all keys (dummy context support) +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.Reset() +} + // Close the storage func (s *Storage) Close() error { s.client.Close() diff --git a/cassandra/cassandra.go b/cassandra/cassandra.go index dddd8488..05f5a48c 100644 --- a/cassandra/cassandra.go +++ b/cassandra/cassandra.go @@ -1,6 +1,7 @@ package cassandra import ( + "context" "fmt" "strings" "time" @@ -218,8 +219,8 @@ type queryResult struct { ExpiresAt time.Time `db:"expires_at"` } -// Set stores a key-value pair with optional expiration -func (s *Storage) Set(key string, value []byte, exp time.Duration) error { +// SetWithContext stores a key-value pair with optional expiration with context support +func (s *Storage) SetWithContext(ctx context.Context, key string, value []byte, exp time.Duration) error { // Validate key if _, err := validateIdentifier(key, "key"); err != nil { return err @@ -256,11 +257,16 @@ func (s *Storage) Set(key string, value []byte, exp time.Duration) error { "key": key, "value": value, "expires_at": expiresAt, - }).ExecRelease() + }).WithContext(ctx).ExecRelease() } -// Get retrieves a value by key -func (s *Storage) Get(key string) ([]byte, error) { +// Set stores a key-value pair with optional expiration +func (s *Storage) Set(key string, value []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, value, exp) +} + +// GetWithContext retrieves a value by key with context support. +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { // Use query builder for select stmt, names := qb.Select(fmt.Sprintf("%s.%s", s.keyspace, s.table)). Columns("value", "expires_at"). @@ -271,7 +277,7 @@ func (s *Storage) Get(key string) ([]byte, error) { // Use gocqlx session if err := s.sx.Query(stmt, names).BindMap(map[string]interface{}{ "key": key, - }).GetRelease(&result); err != nil { + }).WithContext(ctx).GetRelease(&result); err != nil { if err == gocql.ErrNotFound { return nil, ErrNotFound } @@ -290,8 +296,13 @@ func (s *Storage) Get(key string) ([]byte, error) { return result.Value, nil } -// Delete removes a key from storage -func (s *Storage) Delete(key string) error { +// Get retrieves a value by key. +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// DeleteWithContext removes a key from storage with context support. +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { // Use query builder for delete stmt, names := qb.Delete(fmt.Sprintf("%s.%s", s.keyspace, s.table)). Where(qb.Eq("key")). @@ -300,14 +311,26 @@ func (s *Storage) Delete(key string) error { // Use gocqlx session return s.sx.Query(stmt, names).BindMap(map[string]interface{}{ "key": key, - }).ExecRelease() + }).WithContext(ctx).ExecRelease() } -// Reset clears all keys from storage -func (s *Storage) Reset() error { +// Delete removes a key from storage. +func (s *Storage) Delete(key string) error { + // Use the context-free version + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext clears all keys from storage with context support. +func (s *Storage) ResetWithContext(ctx context.Context) error { // Use direct TRUNCATE query with proper escaping query := fmt.Sprintf("TRUNCATE TABLE %s.%s", s.keyspace, s.table) - return s.sx.Query(query, []string{}).ExecRelease() + return s.sx.Query(query, []string{}).WithContext(ctx).ExecRelease() +} + +// Reset clears all keys from storage. +func (s *Storage) Reset() error { + // Use the context-free version + return s.ResetWithContext(context.Background()) } // Conn returns the underlying gocql session. diff --git a/cassandra/cassandra_test.go b/cassandra/cassandra_test.go index f360737e..69faf9d0 100644 --- a/cassandra/cassandra_test.go +++ b/cassandra/cassandra_test.go @@ -66,6 +66,22 @@ func Test_Set(t *testing.T) { require.Equal(t, []byte("value"), val) } +func Test_SetWithContext(t *testing.T) { + store := newTestStore(t) + defer store.Close() + + // Test SetWithContext + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := store.SetWithContext(ctx, "test", []byte("value"), 0) + require.ErrorIs(t, err, context.Canceled) + + // Verify the value was not set + val, err := store.Get("test") + require.Error(t, err) + require.Empty(t, val) +} + // Test_Get tests the Get operation func Test_Get(t *testing.T) { store := newTestStore(t) @@ -86,6 +102,28 @@ func Test_Get(t *testing.T) { require.Nil(t, val) } +// Test_GetWithContext tests the Get operation with context +func Test_GetWithContext(t *testing.T) { + store := newTestStore(t) + defer store.Close() + + // Set a value first + err := store.Set("test", []byte("value"), 0) + require.NoError(t, err) + + // Test GetWithContext + ctx, cancel := context.WithCancel(context.Background()) + cancel() + val, err := store.GetWithContext(ctx, "test") + require.ErrorIs(t, err, context.Canceled) + require.Nil(t, val) + + // Verify the value still exists + val, err = store.Get("test") + require.NoError(t, err) + require.Equal(t, []byte("value"), val) +} + // Test_Delete tests the Delete operation func Test_Delete(t *testing.T) { store := newTestStore(t) @@ -110,6 +148,32 @@ func Test_Delete(t *testing.T) { require.Nil(t, val) } +// Test_DeleteWithContext tests the Delete operation with context +func Test_DeleteWithContext(t *testing.T) { + store := newTestStore(t) + defer store.Close() + + // Set a value first + err := store.Set("test", []byte("value"), 0) + require.NoError(t, err) + + // Verify the value exists + val, err := store.Get("test") + require.NoError(t, err) + require.Equal(t, []byte("value"), val) + + // Test DeleteWithContext + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = store.DeleteWithContext(ctx, "test") + require.ErrorIs(t, err, context.Canceled) + + // Verify the value still exists + val, err = store.Get("test") + require.NoError(t, err) + require.Equal(t, []byte("value"), val) +} + // Test_Expirable_Keys tests the expirable keys functionality func Test_Expirable_Keys(t *testing.T) { store := newTestStore(t) @@ -180,6 +244,33 @@ func Test_Reset(t *testing.T) { require.Nil(t, val) } +// Test_ResetWithContext tests the Reset method with context +func Test_ResetWithContext(t *testing.T) { + store := newTestStore(t) + defer store.Close() + + // Add some data + err := store.Set("test1", []byte("value1"), 0) + require.NoError(t, err) + err = store.Set("test2", []byte("value2"), 0) + require.NoError(t, err) + + // Reset storage with context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = store.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + // Verify data is still there + val, err := store.Get("test1") + require.NoError(t, err) + require.Equal(t, []byte("value1"), val) + + val, err = store.Get("test2") + require.NoError(t, err) + require.Equal(t, []byte("value2"), val) +} + // Test_Valid_Identifiers tests valid identifier cases func Test_Valid_Identifiers(t *testing.T) { store := newTestStore(t) diff --git a/surrealdb/surrealdb.go b/surrealdb/surrealdb.go index 0d5f23a2..75c01360 100644 --- a/surrealdb/surrealdb.go +++ b/surrealdb/surrealdb.go @@ -1,11 +1,13 @@ package surrealdb import ( + "context" "encoding/json" "errors" + "time" + "github.com/surrealdb/surrealdb.go" "github.com/surrealdb/surrealdb.go/pkg/models" - "time" ) // Storage interface that is implemented by storage providers @@ -17,12 +19,6 @@ type Storage struct { } // model represents a key-value storage record used in SurrealDB. -// It contains the key name, the stored byte value, and an optional expiration timestamp. -// -// Fields: -// - Key: the unique identifier for the stored item. -// - Body: the value stored as a byte slice (can represent any serialized data, such as JSON). -// - Exp: the expiration time as a Unix timestamp (0 means no expiration). type model struct { Key string `json:"key"` Body []byte `json:"body"` @@ -30,7 +26,6 @@ type model struct { } // New creates a new SurrealDB storage instance using the provided configuration. -// Returns an error if the connection or authentication fails. func New(config ...Config) *Storage { cfg := configDefault(config...) db, err := surrealdb.New(cfg.ConnectionString) @@ -67,6 +62,7 @@ func New(config ...Config) *Storage { return storage } +// Get returns the value by key func (s *Storage) Get(key string) ([]byte, error) { if len(key) == 0 { return nil, errors.New("key is required") @@ -86,6 +82,12 @@ func (s *Storage) Get(key string) ([]byte, error) { return m.Body, nil } +// GetWithContext dummy context support: calls Get ignoring ctx +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + return s.Get(key) +} + +// Set sets a value by key with optional expiration func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if len(key) == 0 { return errors.New("key is required") @@ -96,7 +98,6 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { expiresAt = time.Now().Add(exp).Unix() } - // Upsert is used instead of Create to allow overriding the same key if it already exists. _, err := surrealdb.Upsert[model](s.db, models.NewRecordID(s.table, key), &model{ Key: key, Body: val, @@ -105,6 +106,12 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return err } +// SetWithContext dummy context support: calls Set ignoring ctx +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + return s.Set(key, val, exp) +} + +// Delete removes a key from storage func (s *Storage) Delete(key string) error { if len(key) == 0 { return errors.New("key is required") @@ -114,20 +121,34 @@ func (s *Storage) Delete(key string) error { return err } +// DeleteWithContext dummy context support: calls Delete ignoring ctx +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.Delete(key) +} + +// Reset clears all keys in the storage table func (s *Storage) Reset() error { _, err := surrealdb.Delete[[]model](s.db, models.Table(s.table)) return err } +// ResetWithContext dummy context support: calls Reset ignoring ctx +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.Reset() +} + +// Close stops GC and closes the DB connection func (s *Storage) Close() error { close(s.stopGC) return s.db.Close() } +// Conn returns the underlying SurrealDB client func (s *Storage) Conn() *surrealdb.DB { return s.db } +// List returns all stored keys and values as JSON func (s *Storage) List() ([]byte, error) { records, err := surrealdb.Select[[]model, models.Table](s.db, models.Table(s.table)) if err != nil { @@ -148,6 +169,7 @@ func (s *Storage) List() ([]byte, error) { return json.Marshal(data) } +// gc runs periodic cleanup of expired keys func (s *Storage) gc() { ticker := time.NewTicker(s.interval) defer ticker.Stop() @@ -162,6 +184,7 @@ func (s *Storage) gc() { } } +// cleanupExpired deletes expired keys from storage func (s *Storage) cleanupExpired() { records, err := surrealdb.Select[[]model, models.Table](s.db, models.Table(s.table)) if err != nil {