Refactor and optimize code

This commit is contained in:
MitulShah1
2025-04-20 12:16:35 +05:30
parent bd94f13340
commit 40586cc9df
4 changed files with 140 additions and 58 deletions

View File

@@ -23,7 +23,7 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error
func (s *Storage) Delete(key string) error func (s *Storage) Delete(key string) error
func (s *Storage) Reset() error func (s *Storage) Reset() error
func (s *Storage) Close() error func (s *Storage) Close() error
func (s *Storage) Conn() *Session func (s *Storage) Conn() *gocql.Session
``` ```
### Installation ### Installation

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"regexp"
"time" "time"
"github.com/gocql/gocql" "github.com/gocql/gocql"
@@ -18,48 +19,68 @@ type Storage struct {
ttl int ttl int
} }
// New creates a new Cassandra storage instance var (
func New(cnfg Config) *Storage { // identifierPattern matches valid Cassandra identifiers
identifierPattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
)
// validateIdentifier checks if an identifier is valid
func validateIdentifier(name, field string) (string, error) {
if !identifierPattern.MatchString(name) {
return "", fmt.Errorf("invalid %s name: must contain only alphanumeric characters and underscores", field)
}
return name, nil
}
// New creates a new Cassandra storage instance
func New(cnfg Config) (*Storage, error) {
// Default config // Default config
cfg := configDefault(cnfg) cfg := configDefault(cnfg)
// Validate and escape identifiers
keyspace, err := validateIdentifier(cfg.Keyspace, "keyspace")
if err != nil {
return nil, err
}
table, err := validateIdentifier(cfg.Table, "table")
if err != nil {
return nil, err
}
// Create cluster config // Create cluster config
cluster := gocql.NewCluster(cfg.Hosts...) cluster := gocql.NewCluster(cfg.Hosts...)
cluster.Consistency = cfg.Consistency cluster.Consistency = cfg.Consistency
// Don't set keyspace initially - we need to create it first
// We'll connect to system keyspace first
// Convert expiration to seconds for TTL // Convert expiration to seconds for TTL
ttl := 0 ttl := 0
if cfg.Expiration > 0 { if cfg.Expiration > 0 {
ttl = int(cfg.Expiration.Seconds()) ttl = int(cfg.Expiration.Seconds())
} else if cfg.Expiration < 0 {
// Expiration < 0 means indefinite storage
cfg.Expiration = 0
} }
// Create storage instance // Create storage instance
storage := &Storage{ storage := &Storage{
cluster: cluster, cluster: cluster,
keyspace: cfg.Keyspace, keyspace: keyspace,
table: cfg.Table, table: table,
ttl: ttl, ttl: ttl,
} }
// Initialize keyspace // Initialize keyspace
if err := storage.createOrVerifyKeySpace(cfg.Reset); err != nil { if err := storage.createOrVerifyKeySpace(cfg.Reset); err != nil {
log.Printf("Failed to initialize keyspace: %v", err) return nil, fmt.Errorf("cassandra storage init: %w", err)
panic(err)
} }
return storage return storage, nil
} }
// createOrVerifyKeySpace ensures the keyspace and table exist with proper keyspace // createOrVerifyKeySpace ensures the keyspace and table exist with proper keyspace
func (s *Storage) createOrVerifyKeySpace(reset bool) error { func (s *Storage) createOrVerifyKeySpace(reset bool) error {
// Connect to system keyspace first to create our keyspace if needed // Clone the original cluster config and set system keyspace
systemCluster := gocql.NewCluster(s.cluster.Hosts...) systemCluster := *s.cluster
systemCluster.Consistency = s.cluster.Consistency systemCluster.Keyspace = "system"
systemCluster.Timeout = s.cluster.Timeout
// Connect to the system keyspace // Connect to the system keyspace
systemSession, err := systemCluster.CreateSession() systemSession, err := systemCluster.CreateSession()
@@ -153,7 +174,7 @@ func (s *Storage) dropTables() error {
func (s *Storage) Set(key string, value []byte, exp time.Duration) error { func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
// Calculate expiration time // Calculate expiration time
var expiresAt *time.Time var expiresAt *time.Time
var ttl = -1 // Default to no TTL var ttl int
if exp > 0 { if exp > 0 {
// Specific expiration provided // Specific expiration provided
@@ -166,7 +187,7 @@ func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
t := time.Now().Add(time.Duration(s.ttl) * time.Second) t := time.Now().Add(time.Duration(s.ttl) * time.Second)
expiresAt = &t expiresAt = &t
} }
// If exp < 0, we'll use no TTL (indefinite storage) // If exp == 0 and s.ttl == 0, no TTL will be set (live forever)
// Insert with TTL if specified // Insert with TTL if specified
var query string var query string
@@ -219,6 +240,11 @@ func (s *Storage) Reset() error {
return s.session.Query(query).Exec() return s.session.Query(query).Exec()
} }
// Conn returns the underlying gocql session.
func (s *Storage) Conn() *gocql.Session {
return s.session
}
// Close closes the storage connection // Close closes the storage connection
func (s *Storage) Close() { func (s *Storage) Close() {
if s.session != nil { if s.session != nil {

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"sync"
"testing" "testing"
"time" "time"
@@ -51,63 +52,72 @@ func TestCassandraStorage(t *testing.T) {
// Test cases // Test cases
t.Run("KeyspaceCreation", func(t *testing.T) { t.Run("KeyspaceCreation", func(t *testing.T) {
connectionURL := newTestStore(t) connectionURL := newTestStore(t)
store := New(Config{ store, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_keyspace_creation", Keyspace: "test_keyspace_creation",
Table: "test_kv", Table: "test_kv",
Expiration: 5 * time.Second, Expiration: 5 * time.Second,
}) })
require.NoError(t, err)
defer store.Close() defer store.Close()
testKeyspaceCreation(t, connectionURL, store) testKeyspaceCreation(t, connectionURL, store)
}) })
t.Run("BasicOperations", func(t *testing.T) { t.Run("BasicOperations", func(t *testing.T) {
connectionURL := newTestStore(t) connectionURL := newTestStore(t)
store := New(Config{ store, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_basic_ops", Keyspace: "test_basic_ops",
Table: "test_kv", Table: "test_kv",
Expiration: 5 * time.Second, Expiration: 5 * time.Second,
}) })
require.NoError(t, err)
defer store.Close() defer store.Close()
testBasicOperations(t, store) testBasicOperations(t, store)
}) })
t.Run("ExpirableKeys", func(t *testing.T) { t.Run("ExpirableKeys", func(t *testing.T) {
connectionURL := newTestStore(t) connectionURL := newTestStore(t)
store := New(Config{ store, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_expirable", Keyspace: "test_expirable",
Table: "test_kv", Table: "test_kv",
Expiration: 5 * time.Second, Expiration: 5 * time.Second,
}) })
require.NoError(t, err)
defer store.Close() defer store.Close()
testExpirableKeys(t, store) testExpirableKeys(t, store)
}) })
t.Run("ConcurrentAccess", func(t *testing.T) { t.Run("ConcurrentAccess", func(t *testing.T) {
connectionURL := newTestStore(t) connectionURL := newTestStore(t)
store := New(Config{ store, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_concurrent", Keyspace: "test_concurrent",
Table: "test_kv", Table: "test_kv",
Expiration: 5 * time.Second, Expiration: 5 * time.Second,
}) })
require.NoError(t, err)
defer store.Close() defer store.Close()
testConcurrentAccess(t, store) testConcurrentAccess(t, store)
}) })
t.Run("Reset", func(t *testing.T) { t.Run("Reset", func(t *testing.T) {
connectionURL := newTestStore(t) connectionURL := newTestStore(t)
store := New(Config{ store, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_reset", Keyspace: "test_reset",
Table: "test_kv", Table: "test_kv",
Expiration: 5 * time.Second, Expiration: 5 * time.Second,
}) })
require.NoError(t, err)
defer store.Close() defer store.Close()
testReset(t, connectionURL, store) testReset(t, connectionURL, store)
}) })
t.Run("IdentifierValidation", func(t *testing.T) {
testIdentifierValidation(t)
})
} }
// testKeyspaceCreation tests the keyspace creation functionality. // testKeyspaceCreation tests the keyspace creation functionality.
@@ -171,7 +181,6 @@ func testBasicOperations(t *testing.T, store *Storage) {
// testExpirableKeys tests the expirable keys functionality. // testExpirableKeys tests the expirable keys functionality.
func testExpirableKeys(t *testing.T, store *Storage) { func testExpirableKeys(t *testing.T, store *Storage) {
// Set keys with different expiration settings // Set keys with different expiration settings
// Key with default TTL (exp = 0 means use default) // Key with default TTL (exp = 0 means use default)
err := store.Set("key_default_ttl", []byte("value1"), 0) err := store.Set("key_default_ttl", []byte("value1"), 0)
@@ -199,37 +208,37 @@ func testExpirableKeys(t *testing.T, store *Storage) {
require.Equal(t, []byte("value3"), value) require.Equal(t, []byte("value3"), value)
// Wait for specific TTL to expire // Wait for specific TTL to expire
time.Sleep(1500 * time.Millisecond) require.Eventually(t, func() bool {
v, _ := store.Get("key_specific_ttl")
// Specific TTL key should be gone, others should remain return v == nil
value, err = store.Get("key_specific_ttl") }, 3*time.Second, 100*time.Millisecond,
require.NoError(t, err) "Key with 1s TTL should have expired")
require.Nil(t, value, "Key with 1s TTL should have expired")
// Default TTL key should still exist
value, err = store.Get("key_default_ttl") value, err = store.Get("key_default_ttl")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []byte("value1"), value, "Key with default TTL should still exist") require.Equal(t, []byte("value1"), value, "Key with default TTL should still exist")
// No TTL key should still exist
value, err = store.Get("key_no_ttl") value, err = store.Get("key_no_ttl")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []byte("value3"), value, "Key with no TTL should still exist") require.Equal(t, []byte("value3"), value, "Key with no TTL should still exist")
// Wait for default TTL to expire // Wait for default TTL to expire
time.Sleep(4 * time.Second) require.Eventually(t, func() bool {
v, _ := store.Get("key_default_ttl")
// Default TTL key should be gone, no TTL key should remain return v == nil
value, err = store.Get("key_default_ttl") }, 6*time.Second, 100*time.Millisecond,
require.NoError(t, err) "Key with default TTL should have expired")
require.Nil(t, value, "Key with default TTL should have expired")
// No TTL key should still exist
value, err = store.Get("key_no_ttl") value, err = store.Get("key_no_ttl")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []byte("value3"), value, "Key with no TTL should still exist") require.Equal(t, []byte("value3"), value, "Key with no TTL should still exist")
} }
// / testReset tests the Reset method. // testReset tests the Reset method.
func testReset(t *testing.T, connectionURL string, store *Storage) { func testReset(t *testing.T, connectionURL string, store *Storage) {
// Set some keys // Set some keys
err := store.Set("key1", []byte("value1"), 0) err := store.Set("key1", []byte("value1"), 0)
require.NoError(t, err) require.NoError(t, err)
@@ -255,35 +264,40 @@ func testReset(t *testing.T, connectionURL string, store *Storage) {
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, value, "Key should be deleted after reset") require.Nil(t, value, "Key should be deleted after reset")
// Close the first store before creating a new one
store.Close()
// Create new storage with Reset flag // Create new storage with Reset flag
store = New(Config{ newStore, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_reset", Keyspace: "test_reset",
Table: "test_kv", Table: "test_kv",
Reset: true, Reset: true,
}) })
defer store.Close() require.NoError(t, err)
defer newStore.Close()
// Set a key // Set a key
err = store.Set("key3", []byte("value3"), 0) err = newStore.Set("key3", []byte("value3"), 0)
require.NoError(t, err) require.NoError(t, err)
// Verify key exists // Verify key exists
value, err = store.Get("key3") value, err = newStore.Get("key3")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []byte("value3"), value) require.Equal(t, []byte("value3"), value)
} }
// testConcurrentAccess tests concurrent access to the storage. // testConcurrentAccess tests concurrent access to the storage.
func testConcurrentAccess(t *testing.T, store *Storage) { func testConcurrentAccess(t *testing.T, store *Storage) {
// Number of goroutines // Number of goroutines
const concurrentOps = 10 const concurrentOps = 10
done := make(chan bool, concurrentOps) var wg sync.WaitGroup
wg.Add(concurrentOps)
// Run concurrent operations // Run concurrent operations
for i := 0; i < concurrentOps; i++ { for i := 0; i < concurrentOps; i++ {
go func(id int) { go func(id int) {
defer wg.Done()
// Set key // Set key
key := fmt.Sprintf("key%d", id) key := fmt.Sprintf("key%d", id)
value := []byte(fmt.Sprintf("value%d", id)) value := []byte(fmt.Sprintf("value%d", id))
@@ -303,33 +317,28 @@ func testConcurrentAccess(t *testing.T, store *Storage) {
retrievedValue, err = store.Get(key) retrievedValue, err = store.Get(key)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, retrievedValue) require.Nil(t, retrievedValue)
done <- true
}(i) }(i)
} }
// Wait for all goroutines to complete // Wait for all goroutines to complete
for i := 0; i < concurrentOps; i++ { wg.Wait()
<-done
}
} }
func Benchmark_Cassandra_Set(b *testing.B) { func Benchmark_Cassandra_Set(b *testing.B) {
connectionURL := newTestStore(b) connectionURL := newTestStore(b)
// Create new storage // Create new storage
store := New(Config{ store, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_concurrent", Keyspace: "test_concurrent",
Table: "test_kv", Table: "test_kv",
}) })
require.NoError(b, err)
defer store.Close() defer store.Close()
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
var err error
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
err = store.Set("john", []byte("doe"), 0) err = store.Set("john", []byte("doe"), 0)
} }
@@ -338,21 +347,22 @@ func Benchmark_Cassandra_Set(b *testing.B) {
} }
func Benchmark_Cassandra_Get(b *testing.B) { func Benchmark_Cassandra_Get(b *testing.B) {
connectionURL := newTestStore(b) connectionURL := newTestStore(b)
// Create new storage // Create new storage
client := New(Config{ client, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_concurrent", Keyspace: "test_concurrent",
Table: "test_kv", Table: "test_kv",
}) })
require.NoError(b, err)
defer client.Close() defer client.Close()
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
err := client.Set("john", []byte("doe"), 0) err = client.Set("john", []byte("doe"), 0)
require.NoError(b, err)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err = client.Get("john") _, err = client.Get("john")
@@ -362,21 +372,20 @@ func Benchmark_Cassandra_Get(b *testing.B) {
} }
func Benchmark_Cassandra_Set_And_Delete(b *testing.B) { func Benchmark_Cassandra_Set_And_Delete(b *testing.B) {
connectionURL := newTestStore(b) connectionURL := newTestStore(b)
// Create new storage // Create new storage
client := New(Config{ client, err := New(Config{
Hosts: []string{connectionURL}, Hosts: []string{connectionURL},
Keyspace: "test_concurrent", Keyspace: "test_concurrent",
Table: "test_kv", Table: "test_kv",
}) })
require.NoError(b, err)
defer client.Close() defer client.Close()
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
var err error
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = client.Set("john", []byte("doe"), 0) _ = client.Set("john", []byte("doe"), 0)
err = client.Delete("john") err = client.Delete("john")
@@ -384,3 +393,47 @@ func Benchmark_Cassandra_Set_And_Delete(b *testing.B) {
require.NoError(b, err) require.NoError(b, err)
} }
// testIdentifierValidation tests the validateIdentifier function
func testIdentifierValidation(t *testing.T) {
// Test valid identifiers
validCases := []string{
"test",
"test123",
"test_123",
"TEST",
"Test123",
"test_table",
}
for _, tc := range validCases {
t.Run(fmt.Sprintf("valid_%s", tc), func(t *testing.T) {
result, err := validateIdentifier(tc, "test")
require.NoError(t, err)
require.Equal(t, tc, result)
})
}
// Test invalid identifiers
invalidCases := []struct {
name string
value string
}{
{"empty", ""},
{"space", "test table"},
{"hyphen", "test-table"},
{"dot", "test.table"},
{"quote", `test"table`},
{"semicolon", "test;table"},
{"sql_injection", `test"; DROP KEYSPACE prod; --`},
{"unicode", "test表"},
}
for _, tc := range invalidCases {
t.Run(fmt.Sprintf("invalid_%s", tc.name), func(t *testing.T) {
_, err := validateIdentifier(tc.value, "test")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid test name")
})
}
}

View File

@@ -38,7 +38,7 @@ var ConfigDefault = Config{
Expiration: 10 * time.Minute, Expiration: 10 * time.Minute,
} }
// ConfigDefault is the Helper function to apply default config // configDefault applies `ConfigDefault` values to a usersupplied Config.
func configDefault(config ...Config) Config { func configDefault(config ...Config) Config {
// Return default config if nothing provided // Return default config if nothing provided
if len(config) < 1 { if len(config) < 1 {
@@ -67,6 +67,9 @@ func configDefault(config ...Config) Config {
if cfg.Expiration == 0 { if cfg.Expiration == 0 {
cfg.Expiration = ConfigDefault.Expiration cfg.Expiration = ConfigDefault.Expiration
} else if cfg.Expiration < 0 {
// Disallow negative expirations they produce invalid TTLs.
cfg.Expiration = 0
} }
return cfg return cfg