add context support to more storages

This commit is contained in:
Muhammed Efe Cetin
2025-06-25 14:36:56 +03:00
parent 75b43b2ac4
commit 2841c64d32
4 changed files with 179 additions and 21 deletions

View File

@@ -1,6 +1,7 @@
package aerospike
import (
"context"
"log"
"time"
@@ -216,6 +217,11 @@ func (s *Storage) Get(key string) ([]byte, error) {
return data, nil
}
// GetWithContext gets value by key (dummy context support)
func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
return s.Get(key)
}
// Set key with value
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
k, err := aerospike.NewKey(s.namespace, s.setName, key)
@@ -242,6 +248,11 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
return s.client.Put(writePolicy, k, bins)
}
// SetWithContext sets value by key (dummy context support)
func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
return s.Set(key, val, exp)
}
// Delete key
func (s *Storage) Delete(key string) error {
k, err := aerospike.NewKey(s.namespace, s.setName, key)
@@ -253,6 +264,11 @@ func (s *Storage) Delete(key string) error {
return err
}
// DeleteWithContext deletes key (dummy context support)
func (s *Storage) DeleteWithContext(ctx context.Context, key string) error {
return s.Delete(key)
}
// Reset all keys
func (s *Storage) Reset() error {
// Use ScanAll which returns a Recordset
@@ -293,6 +309,11 @@ func (s *Storage) Reset() error {
return nil
}
// ResetWithContext resets all keys (dummy context support)
func (s *Storage) ResetWithContext(ctx context.Context) error {
return s.Reset()
}
// Close the storage
func (s *Storage) Close() error {
s.client.Close()

View File

@@ -1,6 +1,7 @@
package cassandra
import (
"context"
"fmt"
"strings"
"time"
@@ -218,8 +219,8 @@ type queryResult struct {
ExpiresAt time.Time `db:"expires_at"`
}
// Set stores a key-value pair with optional expiration
func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
// SetWithContext stores a key-value pair with optional expiration with context support
func (s *Storage) SetWithContext(ctx context.Context, key string, value []byte, exp time.Duration) error {
// Validate key
if _, err := validateIdentifier(key, "key"); err != nil {
return err
@@ -256,11 +257,16 @@ func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
"key": key,
"value": value,
"expires_at": expiresAt,
}).ExecRelease()
}).WithContext(ctx).ExecRelease()
}
// Get retrieves a value by key
func (s *Storage) Get(key string) ([]byte, error) {
// Set stores a key-value pair with optional expiration
func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
return s.SetWithContext(context.Background(), key, value, exp)
}
// GetWithContext retrieves a value by key with context support.
func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
// Use query builder for select
stmt, names := qb.Select(fmt.Sprintf("%s.%s", s.keyspace, s.table)).
Columns("value", "expires_at").
@@ -271,7 +277,7 @@ func (s *Storage) Get(key string) ([]byte, error) {
// Use gocqlx session
if err := s.sx.Query(stmt, names).BindMap(map[string]interface{}{
"key": key,
}).GetRelease(&result); err != nil {
}).WithContext(ctx).GetRelease(&result); err != nil {
if err == gocql.ErrNotFound {
return nil, ErrNotFound
}
@@ -290,8 +296,13 @@ func (s *Storage) Get(key string) ([]byte, error) {
return result.Value, nil
}
// Delete removes a key from storage
func (s *Storage) Delete(key string) error {
// Get retrieves a value by key.
func (s *Storage) Get(key string) ([]byte, error) {
return s.GetWithContext(context.Background(), key)
}
// DeleteWithContext removes a key from storage with context support.
func (s *Storage) DeleteWithContext(ctx context.Context, key string) error {
// Use query builder for delete
stmt, names := qb.Delete(fmt.Sprintf("%s.%s", s.keyspace, s.table)).
Where(qb.Eq("key")).
@@ -300,14 +311,26 @@ func (s *Storage) Delete(key string) error {
// Use gocqlx session
return s.sx.Query(stmt, names).BindMap(map[string]interface{}{
"key": key,
}).ExecRelease()
}).WithContext(ctx).ExecRelease()
}
// Reset clears all keys from storage
func (s *Storage) Reset() error {
// Delete removes a key from storage.
func (s *Storage) Delete(key string) error {
// Use the context-free version
return s.DeleteWithContext(context.Background(), key)
}
// ResetWithContext clears all keys from storage with context support.
func (s *Storage) ResetWithContext(ctx context.Context) error {
// Use direct TRUNCATE query with proper escaping
query := fmt.Sprintf("TRUNCATE TABLE %s.%s", s.keyspace, s.table)
return s.sx.Query(query, []string{}).ExecRelease()
return s.sx.Query(query, []string{}).WithContext(ctx).ExecRelease()
}
// Reset clears all keys from storage.
func (s *Storage) Reset() error {
// Use the context-free version
return s.ResetWithContext(context.Background())
}
// Conn returns the underlying gocql session.

View File

@@ -66,6 +66,22 @@ func Test_Set(t *testing.T) {
require.Equal(t, []byte("value"), val)
}
func Test_SetWithContext(t *testing.T) {
store := newTestStore(t)
defer store.Close()
// Test SetWithContext
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := store.SetWithContext(ctx, "test", []byte("value"), 0)
require.ErrorIs(t, err, context.Canceled)
// Verify the value was not set
val, err := store.Get("test")
require.Error(t, err)
require.Empty(t, val)
}
// Test_Get tests the Get operation
func Test_Get(t *testing.T) {
store := newTestStore(t)
@@ -86,6 +102,28 @@ func Test_Get(t *testing.T) {
require.Nil(t, val)
}
// Test_GetWithContext tests the Get operation with context
func Test_GetWithContext(t *testing.T) {
store := newTestStore(t)
defer store.Close()
// Set a value first
err := store.Set("test", []byte("value"), 0)
require.NoError(t, err)
// Test GetWithContext
ctx, cancel := context.WithCancel(context.Background())
cancel()
val, err := store.GetWithContext(ctx, "test")
require.ErrorIs(t, err, context.Canceled)
require.Nil(t, val)
// Verify the value still exists
val, err = store.Get("test")
require.NoError(t, err)
require.Equal(t, []byte("value"), val)
}
// Test_Delete tests the Delete operation
func Test_Delete(t *testing.T) {
store := newTestStore(t)
@@ -110,6 +148,32 @@ func Test_Delete(t *testing.T) {
require.Nil(t, val)
}
// Test_DeleteWithContext tests the Delete operation with context
func Test_DeleteWithContext(t *testing.T) {
store := newTestStore(t)
defer store.Close()
// Set a value first
err := store.Set("test", []byte("value"), 0)
require.NoError(t, err)
// Verify the value exists
val, err := store.Get("test")
require.NoError(t, err)
require.Equal(t, []byte("value"), val)
// Test DeleteWithContext
ctx, cancel := context.WithCancel(context.Background())
cancel()
err = store.DeleteWithContext(ctx, "test")
require.ErrorIs(t, err, context.Canceled)
// Verify the value still exists
val, err = store.Get("test")
require.NoError(t, err)
require.Equal(t, []byte("value"), val)
}
// Test_Expirable_Keys tests the expirable keys functionality
func Test_Expirable_Keys(t *testing.T) {
store := newTestStore(t)
@@ -180,6 +244,33 @@ func Test_Reset(t *testing.T) {
require.Nil(t, val)
}
// Test_ResetWithContext tests the Reset method with context
func Test_ResetWithContext(t *testing.T) {
store := newTestStore(t)
defer store.Close()
// Add some data
err := store.Set("test1", []byte("value1"), 0)
require.NoError(t, err)
err = store.Set("test2", []byte("value2"), 0)
require.NoError(t, err)
// Reset storage with context
ctx, cancel := context.WithCancel(context.Background())
cancel()
err = store.ResetWithContext(ctx)
require.ErrorIs(t, err, context.Canceled)
// Verify data is still there
val, err := store.Get("test1")
require.NoError(t, err)
require.Equal(t, []byte("value1"), val)
val, err = store.Get("test2")
require.NoError(t, err)
require.Equal(t, []byte("value2"), val)
}
// Test_Valid_Identifiers tests valid identifier cases
func Test_Valid_Identifiers(t *testing.T) {
store := newTestStore(t)

View File

@@ -1,11 +1,13 @@
package surrealdb
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/surrealdb/surrealdb.go"
"github.com/surrealdb/surrealdb.go/pkg/models"
"time"
)
// Storage interface that is implemented by storage providers
@@ -17,12 +19,6 @@ type Storage struct {
}
// model represents a key-value storage record used in SurrealDB.
// It contains the key name, the stored byte value, and an optional expiration timestamp.
//
// Fields:
// - Key: the unique identifier for the stored item.
// - Body: the value stored as a byte slice (can represent any serialized data, such as JSON).
// - Exp: the expiration time as a Unix timestamp (0 means no expiration).
type model struct {
Key string `json:"key"`
Body []byte `json:"body"`
@@ -30,7 +26,6 @@ type model struct {
}
// New creates a new SurrealDB storage instance using the provided configuration.
// Returns an error if the connection or authentication fails.
func New(config ...Config) *Storage {
cfg := configDefault(config...)
db, err := surrealdb.New(cfg.ConnectionString)
@@ -67,6 +62,7 @@ func New(config ...Config) *Storage {
return storage
}
// Get returns the value by key
func (s *Storage) Get(key string) ([]byte, error) {
if len(key) == 0 {
return nil, errors.New("key is required")
@@ -86,6 +82,12 @@ func (s *Storage) Get(key string) ([]byte, error) {
return m.Body, nil
}
// GetWithContext dummy context support: calls Get ignoring ctx
func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
return s.Get(key)
}
// Set sets a value by key with optional expiration
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
if len(key) == 0 {
return errors.New("key is required")
@@ -96,7 +98,6 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
expiresAt = time.Now().Add(exp).Unix()
}
// Upsert is used instead of Create to allow overriding the same key if it already exists.
_, err := surrealdb.Upsert[model](s.db, models.NewRecordID(s.table, key), &model{
Key: key,
Body: val,
@@ -105,6 +106,12 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
return err
}
// SetWithContext dummy context support: calls Set ignoring ctx
func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
return s.Set(key, val, exp)
}
// Delete removes a key from storage
func (s *Storage) Delete(key string) error {
if len(key) == 0 {
return errors.New("key is required")
@@ -114,20 +121,34 @@ func (s *Storage) Delete(key string) error {
return err
}
// DeleteWithContext dummy context support: calls Delete ignoring ctx
func (s *Storage) DeleteWithContext(ctx context.Context, key string) error {
return s.Delete(key)
}
// Reset clears all keys in the storage table
func (s *Storage) Reset() error {
_, err := surrealdb.Delete[[]model](s.db, models.Table(s.table))
return err
}
// ResetWithContext dummy context support: calls Reset ignoring ctx
func (s *Storage) ResetWithContext(ctx context.Context) error {
return s.Reset()
}
// Close stops GC and closes the DB connection
func (s *Storage) Close() error {
close(s.stopGC)
return s.db.Close()
}
// Conn returns the underlying SurrealDB client
func (s *Storage) Conn() *surrealdb.DB {
return s.db
}
// List returns all stored keys and values as JSON
func (s *Storage) List() ([]byte, error) {
records, err := surrealdb.Select[[]model, models.Table](s.db, models.Table(s.table))
if err != nil {
@@ -148,6 +169,7 @@ func (s *Storage) List() ([]byte, error) {
return json.Marshal(data)
}
// gc runs periodic cleanup of expired keys
func (s *Storage) gc() {
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
@@ -162,6 +184,7 @@ func (s *Storage) gc() {
}
}
// cleanupExpired deletes expired keys from storage
func (s *Storage) cleanupExpired() {
records, err := surrealdb.Select[[]model, models.Table](s.db, models.Table(s.table))
if err != nil {