mirror of
https://github.com/gofiber/storage.git
synced 2025-10-05 08:37:10 +08:00
Refactor and optimize code
This commit is contained in:
@@ -23,7 +23,7 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error
|
||||
func (s *Storage) Delete(key string) error
|
||||
func (s *Storage) Reset() error
|
||||
func (s *Storage) Close() error
|
||||
func (s *Storage) Conn() *Session
|
||||
func (s *Storage) Conn() *gocql.Session
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
@@ -18,48 +19,68 @@ type Storage struct {
|
||||
ttl int
|
||||
}
|
||||
|
||||
// New creates a new Cassandra storage instance
|
||||
func New(cnfg Config) *Storage {
|
||||
var (
|
||||
// identifierPattern matches valid Cassandra identifiers
|
||||
identifierPattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
)
|
||||
|
||||
// validateIdentifier checks if an identifier is valid
|
||||
func validateIdentifier(name, field string) (string, error) {
|
||||
if !identifierPattern.MatchString(name) {
|
||||
return "", fmt.Errorf("invalid %s name: must contain only alphanumeric characters and underscores", field)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// New creates a new Cassandra storage instance
|
||||
func New(cnfg Config) (*Storage, error) {
|
||||
// Default config
|
||||
cfg := configDefault(cnfg)
|
||||
|
||||
// Validate and escape identifiers
|
||||
keyspace, err := validateIdentifier(cfg.Keyspace, "keyspace")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
table, err := validateIdentifier(cfg.Table, "table")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create cluster config
|
||||
cluster := gocql.NewCluster(cfg.Hosts...)
|
||||
cluster.Consistency = cfg.Consistency
|
||||
|
||||
// Don't set keyspace initially - we need to create it first
|
||||
// We'll connect to system keyspace first
|
||||
|
||||
// Convert expiration to seconds for TTL
|
||||
ttl := 0
|
||||
if cfg.Expiration > 0 {
|
||||
ttl = int(cfg.Expiration.Seconds())
|
||||
} else if cfg.Expiration < 0 {
|
||||
// Expiration < 0 means indefinite storage
|
||||
cfg.Expiration = 0
|
||||
}
|
||||
|
||||
// Create storage instance
|
||||
storage := &Storage{
|
||||
cluster: cluster,
|
||||
keyspace: cfg.Keyspace,
|
||||
table: cfg.Table,
|
||||
keyspace: keyspace,
|
||||
table: table,
|
||||
ttl: ttl,
|
||||
}
|
||||
|
||||
// Initialize keyspace
|
||||
if err := storage.createOrVerifyKeySpace(cfg.Reset); err != nil {
|
||||
log.Printf("Failed to initialize keyspace: %v", err)
|
||||
panic(err)
|
||||
return nil, fmt.Errorf("cassandra storage init: %w", err)
|
||||
}
|
||||
|
||||
return storage
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// createOrVerifyKeySpace ensures the keyspace and table exist with proper keyspace
|
||||
func (s *Storage) createOrVerifyKeySpace(reset bool) error {
|
||||
// Connect to system keyspace first to create our keyspace if needed
|
||||
systemCluster := gocql.NewCluster(s.cluster.Hosts...)
|
||||
systemCluster.Consistency = s.cluster.Consistency
|
||||
systemCluster.Timeout = s.cluster.Timeout
|
||||
// Clone the original cluster config and set system keyspace
|
||||
systemCluster := *s.cluster
|
||||
systemCluster.Keyspace = "system"
|
||||
|
||||
// Connect to the system keyspace
|
||||
systemSession, err := systemCluster.CreateSession()
|
||||
@@ -153,7 +174,7 @@ func (s *Storage) dropTables() error {
|
||||
func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
|
||||
// Calculate expiration time
|
||||
var expiresAt *time.Time
|
||||
var ttl = -1 // Default to no TTL
|
||||
var ttl int
|
||||
|
||||
if exp > 0 {
|
||||
// Specific expiration provided
|
||||
@@ -166,7 +187,7 @@ func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
|
||||
t := time.Now().Add(time.Duration(s.ttl) * time.Second)
|
||||
expiresAt = &t
|
||||
}
|
||||
// If exp < 0, we'll use no TTL (indefinite storage)
|
||||
// If exp == 0 and s.ttl == 0, no TTL will be set (live forever)
|
||||
|
||||
// Insert with TTL if specified
|
||||
var query string
|
||||
@@ -219,6 +240,11 @@ func (s *Storage) Reset() error {
|
||||
return s.session.Query(query).Exec()
|
||||
}
|
||||
|
||||
// Conn returns the underlying gocql session.
|
||||
func (s *Storage) Conn() *gocql.Session {
|
||||
return s.session
|
||||
}
|
||||
|
||||
// Close closes the storage connection
|
||||
func (s *Storage) Close() {
|
||||
if s.session != nil {
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -51,63 +52,72 @@ func TestCassandraStorage(t *testing.T) {
|
||||
// Test cases
|
||||
t.Run("KeyspaceCreation", func(t *testing.T) {
|
||||
connectionURL := newTestStore(t)
|
||||
store := New(Config{
|
||||
store, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_keyspace_creation",
|
||||
Table: "test_kv",
|
||||
Expiration: 5 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
testKeyspaceCreation(t, connectionURL, store)
|
||||
})
|
||||
|
||||
t.Run("BasicOperations", func(t *testing.T) {
|
||||
connectionURL := newTestStore(t)
|
||||
store := New(Config{
|
||||
store, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_basic_ops",
|
||||
Table: "test_kv",
|
||||
Expiration: 5 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
testBasicOperations(t, store)
|
||||
})
|
||||
|
||||
t.Run("ExpirableKeys", func(t *testing.T) {
|
||||
connectionURL := newTestStore(t)
|
||||
store := New(Config{
|
||||
store, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_expirable",
|
||||
Table: "test_kv",
|
||||
Expiration: 5 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
testExpirableKeys(t, store)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentAccess", func(t *testing.T) {
|
||||
connectionURL := newTestStore(t)
|
||||
store := New(Config{
|
||||
store, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_concurrent",
|
||||
Table: "test_kv",
|
||||
Expiration: 5 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
testConcurrentAccess(t, store)
|
||||
})
|
||||
|
||||
t.Run("Reset", func(t *testing.T) {
|
||||
connectionURL := newTestStore(t)
|
||||
store := New(Config{
|
||||
store, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_reset",
|
||||
Table: "test_kv",
|
||||
Expiration: 5 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
testReset(t, connectionURL, store)
|
||||
})
|
||||
|
||||
t.Run("IdentifierValidation", func(t *testing.T) {
|
||||
testIdentifierValidation(t)
|
||||
})
|
||||
}
|
||||
|
||||
// testKeyspaceCreation tests the keyspace creation functionality.
|
||||
@@ -171,7 +181,6 @@ func testBasicOperations(t *testing.T, store *Storage) {
|
||||
|
||||
// testExpirableKeys tests the expirable keys functionality.
|
||||
func testExpirableKeys(t *testing.T, store *Storage) {
|
||||
|
||||
// Set keys with different expiration settings
|
||||
// Key with default TTL (exp = 0 means use default)
|
||||
err := store.Set("key_default_ttl", []byte("value1"), 0)
|
||||
@@ -199,37 +208,37 @@ func testExpirableKeys(t *testing.T, store *Storage) {
|
||||
require.Equal(t, []byte("value3"), value)
|
||||
|
||||
// Wait for specific TTL to expire
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
// Specific TTL key should be gone, others should remain
|
||||
value, err = store.Get("key_specific_ttl")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, value, "Key with 1s TTL should have expired")
|
||||
require.Eventually(t, func() bool {
|
||||
v, _ := store.Get("key_specific_ttl")
|
||||
return v == nil
|
||||
}, 3*time.Second, 100*time.Millisecond,
|
||||
"Key with 1s TTL should have expired")
|
||||
|
||||
// Default TTL key should still exist
|
||||
value, err = store.Get("key_default_ttl")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("value1"), value, "Key with default TTL should still exist")
|
||||
|
||||
// No TTL key should still exist
|
||||
value, err = store.Get("key_no_ttl")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("value3"), value, "Key with no TTL should still exist")
|
||||
|
||||
// Wait for default TTL to expire
|
||||
time.Sleep(4 * time.Second)
|
||||
|
||||
// Default TTL key should be gone, no TTL key should remain
|
||||
value, err = store.Get("key_default_ttl")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, value, "Key with default TTL should have expired")
|
||||
require.Eventually(t, func() bool {
|
||||
v, _ := store.Get("key_default_ttl")
|
||||
return v == nil
|
||||
}, 6*time.Second, 100*time.Millisecond,
|
||||
"Key with default TTL should have expired")
|
||||
|
||||
// No TTL key should still exist
|
||||
value, err = store.Get("key_no_ttl")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("value3"), value, "Key with no TTL should still exist")
|
||||
}
|
||||
|
||||
// / testReset tests the Reset method.
|
||||
// testReset tests the Reset method.
|
||||
func testReset(t *testing.T, connectionURL string, store *Storage) {
|
||||
|
||||
// Set some keys
|
||||
err := store.Set("key1", []byte("value1"), 0)
|
||||
require.NoError(t, err)
|
||||
@@ -255,35 +264,40 @@ func testReset(t *testing.T, connectionURL string, store *Storage) {
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, value, "Key should be deleted after reset")
|
||||
|
||||
// Close the first store before creating a new one
|
||||
store.Close()
|
||||
|
||||
// Create new storage with Reset flag
|
||||
store = New(Config{
|
||||
newStore, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_reset",
|
||||
Table: "test_kv",
|
||||
Reset: true,
|
||||
})
|
||||
defer store.Close()
|
||||
require.NoError(t, err)
|
||||
defer newStore.Close()
|
||||
|
||||
// Set a key
|
||||
err = store.Set("key3", []byte("value3"), 0)
|
||||
err = newStore.Set("key3", []byte("value3"), 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify key exists
|
||||
value, err = store.Get("key3")
|
||||
value, err = newStore.Get("key3")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("value3"), value)
|
||||
}
|
||||
|
||||
// testConcurrentAccess tests concurrent access to the storage.
|
||||
func testConcurrentAccess(t *testing.T, store *Storage) {
|
||||
|
||||
// Number of goroutines
|
||||
const concurrentOps = 10
|
||||
done := make(chan bool, concurrentOps)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentOps)
|
||||
|
||||
// Run concurrent operations
|
||||
for i := 0; i < concurrentOps; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
// Set key
|
||||
key := fmt.Sprintf("key%d", id)
|
||||
value := []byte(fmt.Sprintf("value%d", id))
|
||||
@@ -303,33 +317,28 @@ func testConcurrentAccess(t *testing.T, store *Storage) {
|
||||
retrievedValue, err = store.Get(key)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, retrievedValue)
|
||||
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < concurrentOps; i++ {
|
||||
<-done
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func Benchmark_Cassandra_Set(b *testing.B) {
|
||||
|
||||
connectionURL := newTestStore(b)
|
||||
|
||||
// Create new storage
|
||||
store := New(Config{
|
||||
store, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_concurrent",
|
||||
Table: "test_kv",
|
||||
})
|
||||
require.NoError(b, err)
|
||||
defer store.Close()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
var err error
|
||||
for i := 0; i < b.N; i++ {
|
||||
err = store.Set("john", []byte("doe"), 0)
|
||||
}
|
||||
@@ -338,21 +347,22 @@ func Benchmark_Cassandra_Set(b *testing.B) {
|
||||
}
|
||||
|
||||
func Benchmark_Cassandra_Get(b *testing.B) {
|
||||
|
||||
connectionURL := newTestStore(b)
|
||||
|
||||
// Create new storage
|
||||
client := New(Config{
|
||||
client, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_concurrent",
|
||||
Table: "test_kv",
|
||||
})
|
||||
require.NoError(b, err)
|
||||
defer client.Close()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
err := client.Set("john", []byte("doe"), 0)
|
||||
err = client.Set("john", []byte("doe"), 0)
|
||||
require.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err = client.Get("john")
|
||||
@@ -362,21 +372,20 @@ func Benchmark_Cassandra_Get(b *testing.B) {
|
||||
}
|
||||
|
||||
func Benchmark_Cassandra_Set_And_Delete(b *testing.B) {
|
||||
|
||||
connectionURL := newTestStore(b)
|
||||
|
||||
// Create new storage
|
||||
client := New(Config{
|
||||
client, err := New(Config{
|
||||
Hosts: []string{connectionURL},
|
||||
Keyspace: "test_concurrent",
|
||||
Table: "test_kv",
|
||||
})
|
||||
require.NoError(b, err)
|
||||
defer client.Close()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
var err error
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = client.Set("john", []byte("doe"), 0)
|
||||
err = client.Delete("john")
|
||||
@@ -384,3 +393,47 @@ func Benchmark_Cassandra_Set_And_Delete(b *testing.B) {
|
||||
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// testIdentifierValidation tests the validateIdentifier function
|
||||
func testIdentifierValidation(t *testing.T) {
|
||||
// Test valid identifiers
|
||||
validCases := []string{
|
||||
"test",
|
||||
"test123",
|
||||
"test_123",
|
||||
"TEST",
|
||||
"Test123",
|
||||
"test_table",
|
||||
}
|
||||
|
||||
for _, tc := range validCases {
|
||||
t.Run(fmt.Sprintf("valid_%s", tc), func(t *testing.T) {
|
||||
result, err := validateIdentifier(tc, "test")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc, result)
|
||||
})
|
||||
}
|
||||
|
||||
// Test invalid identifiers
|
||||
invalidCases := []struct {
|
||||
name string
|
||||
value string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"space", "test table"},
|
||||
{"hyphen", "test-table"},
|
||||
{"dot", "test.table"},
|
||||
{"quote", `test"table`},
|
||||
{"semicolon", "test;table"},
|
||||
{"sql_injection", `test"; DROP KEYSPACE prod; --`},
|
||||
{"unicode", "test表"},
|
||||
}
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
t.Run(fmt.Sprintf("invalid_%s", tc.name), func(t *testing.T) {
|
||||
_, err := validateIdentifier(tc.value, "test")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid test name")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -38,7 +38,7 @@ var ConfigDefault = Config{
|
||||
Expiration: 10 * time.Minute,
|
||||
}
|
||||
|
||||
// ConfigDefault is the Helper function to apply default config
|
||||
// configDefault applies `ConfigDefault` values to a user‑supplied Config.
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
@@ -67,6 +67,9 @@ func configDefault(config ...Config) Config {
|
||||
|
||||
if cfg.Expiration == 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
} else if cfg.Expiration < 0 {
|
||||
// Disallow negative expirations – they produce invalid TTLs.
|
||||
cfg.Expiration = 0
|
||||
}
|
||||
|
||||
return cfg
|
||||
|
Reference in New Issue
Block a user