diff --git a/cassandra/README.md b/cassandra/README.md index b68699d7..a4412cb4 100644 --- a/cassandra/README.md +++ b/cassandra/README.md @@ -23,7 +23,7 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error func (s *Storage) Delete(key string) error func (s *Storage) Reset() error func (s *Storage) Close() error -func (s *Storage) Conn() *Session +func (s *Storage) Conn() *gocql.Session ``` ### Installation diff --git a/cassandra/cassandra.go b/cassandra/cassandra.go index fc61adc9..4cd5ea65 100644 --- a/cassandra/cassandra.go +++ b/cassandra/cassandra.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log" + "regexp" "time" "github.com/gocql/gocql" @@ -18,48 +19,68 @@ type Storage struct { ttl int } -// New creates a new Cassandra storage instance -func New(cnfg Config) *Storage { +var ( + // identifierPattern matches valid Cassandra identifiers + identifierPattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) +) +// validateIdentifier checks if an identifier is valid +func validateIdentifier(name, field string) (string, error) { + if !identifierPattern.MatchString(name) { + return "", fmt.Errorf("invalid %s name: must contain only alphanumeric characters and underscores", field) + } + return name, nil +} + +// New creates a new Cassandra storage instance +func New(cnfg Config) (*Storage, error) { // Default config cfg := configDefault(cnfg) + // Validate and escape identifiers + keyspace, err := validateIdentifier(cfg.Keyspace, "keyspace") + if err != nil { + return nil, err + } + table, err := validateIdentifier(cfg.Table, "table") + if err != nil { + return nil, err + } + // Create cluster config cluster := gocql.NewCluster(cfg.Hosts...) cluster.Consistency = cfg.Consistency - // Don't set keyspace initially - we need to create it first - // We'll connect to system keyspace first - // Convert expiration to seconds for TTL ttl := 0 if cfg.Expiration > 0 { ttl = int(cfg.Expiration.Seconds()) + } else if cfg.Expiration < 0 { + // Expiration < 0 means indefinite storage + cfg.Expiration = 0 } // Create storage instance storage := &Storage{ cluster: cluster, - keyspace: cfg.Keyspace, - table: cfg.Table, + keyspace: keyspace, + table: table, ttl: ttl, } // Initialize keyspace if err := storage.createOrVerifyKeySpace(cfg.Reset); err != nil { - log.Printf("Failed to initialize keyspace: %v", err) - panic(err) + return nil, fmt.Errorf("cassandra storage init: %w", err) } - return storage + return storage, nil } // createOrVerifyKeySpace ensures the keyspace and table exist with proper keyspace func (s *Storage) createOrVerifyKeySpace(reset bool) error { - // Connect to system keyspace first to create our keyspace if needed - systemCluster := gocql.NewCluster(s.cluster.Hosts...) - systemCluster.Consistency = s.cluster.Consistency - systemCluster.Timeout = s.cluster.Timeout + // Clone the original cluster config and set system keyspace + systemCluster := *s.cluster + systemCluster.Keyspace = "system" // Connect to the system keyspace systemSession, err := systemCluster.CreateSession() @@ -153,7 +174,7 @@ func (s *Storage) dropTables() error { func (s *Storage) Set(key string, value []byte, exp time.Duration) error { // Calculate expiration time var expiresAt *time.Time - var ttl = -1 // Default to no TTL + var ttl int if exp > 0 { // Specific expiration provided @@ -166,7 +187,7 @@ func (s *Storage) Set(key string, value []byte, exp time.Duration) error { t := time.Now().Add(time.Duration(s.ttl) * time.Second) expiresAt = &t } - // If exp < 0, we'll use no TTL (indefinite storage) + // If exp == 0 and s.ttl == 0, no TTL will be set (live forever) // Insert with TTL if specified var query string @@ -219,6 +240,11 @@ func (s *Storage) Reset() error { return s.session.Query(query).Exec() } +// Conn returns the underlying gocql session. +func (s *Storage) Conn() *gocql.Session { + return s.session +} + // Close closes the storage connection func (s *Storage) Close() { if s.session != nil { diff --git a/cassandra/cassandra_test.go b/cassandra/cassandra_test.go index 21d6e9a4..26705f9a 100644 --- a/cassandra/cassandra_test.go +++ b/cassandra/cassandra_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "sync" "testing" "time" @@ -51,63 +52,72 @@ func TestCassandraStorage(t *testing.T) { // Test cases t.Run("KeyspaceCreation", func(t *testing.T) { connectionURL := newTestStore(t) - store := New(Config{ + store, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_keyspace_creation", Table: "test_kv", Expiration: 5 * time.Second, }) + require.NoError(t, err) defer store.Close() testKeyspaceCreation(t, connectionURL, store) }) t.Run("BasicOperations", func(t *testing.T) { connectionURL := newTestStore(t) - store := New(Config{ + store, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_basic_ops", Table: "test_kv", Expiration: 5 * time.Second, }) + require.NoError(t, err) defer store.Close() testBasicOperations(t, store) }) t.Run("ExpirableKeys", func(t *testing.T) { connectionURL := newTestStore(t) - store := New(Config{ + store, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_expirable", Table: "test_kv", Expiration: 5 * time.Second, }) + require.NoError(t, err) defer store.Close() testExpirableKeys(t, store) }) t.Run("ConcurrentAccess", func(t *testing.T) { connectionURL := newTestStore(t) - store := New(Config{ + store, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_concurrent", Table: "test_kv", Expiration: 5 * time.Second, }) + require.NoError(t, err) defer store.Close() testConcurrentAccess(t, store) }) t.Run("Reset", func(t *testing.T) { connectionURL := newTestStore(t) - store := New(Config{ + store, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_reset", Table: "test_kv", Expiration: 5 * time.Second, }) + require.NoError(t, err) defer store.Close() testReset(t, connectionURL, store) }) + + t.Run("IdentifierValidation", func(t *testing.T) { + testIdentifierValidation(t) + }) } // testKeyspaceCreation tests the keyspace creation functionality. @@ -171,7 +181,6 @@ func testBasicOperations(t *testing.T, store *Storage) { // testExpirableKeys tests the expirable keys functionality. func testExpirableKeys(t *testing.T, store *Storage) { - // Set keys with different expiration settings // Key with default TTL (exp = 0 means use default) err := store.Set("key_default_ttl", []byte("value1"), 0) @@ -199,37 +208,37 @@ func testExpirableKeys(t *testing.T, store *Storage) { require.Equal(t, []byte("value3"), value) // Wait for specific TTL to expire - time.Sleep(1500 * time.Millisecond) - - // Specific TTL key should be gone, others should remain - value, err = store.Get("key_specific_ttl") - require.NoError(t, err) - require.Nil(t, value, "Key with 1s TTL should have expired") + require.Eventually(t, func() bool { + v, _ := store.Get("key_specific_ttl") + return v == nil + }, 3*time.Second, 100*time.Millisecond, + "Key with 1s TTL should have expired") + // Default TTL key should still exist value, err = store.Get("key_default_ttl") require.NoError(t, err) require.Equal(t, []byte("value1"), value, "Key with default TTL should still exist") + // No TTL key should still exist value, err = store.Get("key_no_ttl") require.NoError(t, err) require.Equal(t, []byte("value3"), value, "Key with no TTL should still exist") // Wait for default TTL to expire - time.Sleep(4 * time.Second) - - // Default TTL key should be gone, no TTL key should remain - value, err = store.Get("key_default_ttl") - require.NoError(t, err) - require.Nil(t, value, "Key with default TTL should have expired") + require.Eventually(t, func() bool { + v, _ := store.Get("key_default_ttl") + return v == nil + }, 6*time.Second, 100*time.Millisecond, + "Key with default TTL should have expired") + // No TTL key should still exist value, err = store.Get("key_no_ttl") require.NoError(t, err) require.Equal(t, []byte("value3"), value, "Key with no TTL should still exist") } -// / testReset tests the Reset method. +// testReset tests the Reset method. func testReset(t *testing.T, connectionURL string, store *Storage) { - // Set some keys err := store.Set("key1", []byte("value1"), 0) require.NoError(t, err) @@ -255,35 +264,40 @@ func testReset(t *testing.T, connectionURL string, store *Storage) { require.NoError(t, err) require.Nil(t, value, "Key should be deleted after reset") + // Close the first store before creating a new one + store.Close() + // Create new storage with Reset flag - store = New(Config{ + newStore, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_reset", Table: "test_kv", Reset: true, }) - defer store.Close() + require.NoError(t, err) + defer newStore.Close() // Set a key - err = store.Set("key3", []byte("value3"), 0) + err = newStore.Set("key3", []byte("value3"), 0) require.NoError(t, err) // Verify key exists - value, err = store.Get("key3") + value, err = newStore.Get("key3") require.NoError(t, err) require.Equal(t, []byte("value3"), value) } // testConcurrentAccess tests concurrent access to the storage. func testConcurrentAccess(t *testing.T, store *Storage) { - // Number of goroutines const concurrentOps = 10 - done := make(chan bool, concurrentOps) + var wg sync.WaitGroup + wg.Add(concurrentOps) // Run concurrent operations for i := 0; i < concurrentOps; i++ { go func(id int) { + defer wg.Done() // Set key key := fmt.Sprintf("key%d", id) value := []byte(fmt.Sprintf("value%d", id)) @@ -303,33 +317,28 @@ func testConcurrentAccess(t *testing.T, store *Storage) { retrievedValue, err = store.Get(key) require.NoError(t, err) require.Nil(t, retrievedValue) - - done <- true }(i) } // Wait for all goroutines to complete - for i := 0; i < concurrentOps; i++ { - <-done - } + wg.Wait() } func Benchmark_Cassandra_Set(b *testing.B) { - connectionURL := newTestStore(b) // Create new storage - store := New(Config{ + store, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_concurrent", Table: "test_kv", }) + require.NoError(b, err) defer store.Close() b.ReportAllocs() b.ResetTimer() - var err error for i := 0; i < b.N; i++ { err = store.Set("john", []byte("doe"), 0) } @@ -338,21 +347,22 @@ func Benchmark_Cassandra_Set(b *testing.B) { } func Benchmark_Cassandra_Get(b *testing.B) { - connectionURL := newTestStore(b) // Create new storage - client := New(Config{ + client, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_concurrent", Table: "test_kv", }) + require.NoError(b, err) defer client.Close() b.ReportAllocs() b.ResetTimer() - err := client.Set("john", []byte("doe"), 0) + err = client.Set("john", []byte("doe"), 0) + require.NoError(b, err) for i := 0; i < b.N; i++ { _, err = client.Get("john") @@ -362,21 +372,20 @@ func Benchmark_Cassandra_Get(b *testing.B) { } func Benchmark_Cassandra_Set_And_Delete(b *testing.B) { - connectionURL := newTestStore(b) // Create new storage - client := New(Config{ + client, err := New(Config{ Hosts: []string{connectionURL}, Keyspace: "test_concurrent", Table: "test_kv", }) + require.NoError(b, err) defer client.Close() b.ReportAllocs() b.ResetTimer() - var err error for i := 0; i < b.N; i++ { _ = client.Set("john", []byte("doe"), 0) err = client.Delete("john") @@ -384,3 +393,47 @@ func Benchmark_Cassandra_Set_And_Delete(b *testing.B) { require.NoError(b, err) } + +// testIdentifierValidation tests the validateIdentifier function +func testIdentifierValidation(t *testing.T) { + // Test valid identifiers + validCases := []string{ + "test", + "test123", + "test_123", + "TEST", + "Test123", + "test_table", + } + + for _, tc := range validCases { + t.Run(fmt.Sprintf("valid_%s", tc), func(t *testing.T) { + result, err := validateIdentifier(tc, "test") + require.NoError(t, err) + require.Equal(t, tc, result) + }) + } + + // Test invalid identifiers + invalidCases := []struct { + name string + value string + }{ + {"empty", ""}, + {"space", "test table"}, + {"hyphen", "test-table"}, + {"dot", "test.table"}, + {"quote", `test"table`}, + {"semicolon", "test;table"}, + {"sql_injection", `test"; DROP KEYSPACE prod; --`}, + {"unicode", "test表"}, + } + + for _, tc := range invalidCases { + t.Run(fmt.Sprintf("invalid_%s", tc.name), func(t *testing.T) { + _, err := validateIdentifier(tc.value, "test") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid test name") + }) + } +} diff --git a/cassandra/config.go b/cassandra/config.go index 3fd122a1..c6db983c 100644 --- a/cassandra/config.go +++ b/cassandra/config.go @@ -38,7 +38,7 @@ var ConfigDefault = Config{ Expiration: 10 * time.Minute, } -// ConfigDefault is the Helper function to apply default config +// configDefault applies `ConfigDefault` values to a user‑supplied Config. func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { @@ -67,6 +67,9 @@ func configDefault(config ...Config) Config { if cfg.Expiration == 0 { cfg.Expiration = ConfigDefault.Expiration + } else if cfg.Expiration < 0 { + // Disallow negative expirations – they produce invalid TTLs. + cfg.Expiration = 0 } return cfg