Files
storage/cassandra/cassandra_test.go
2025-04-21 09:37:59 +05:30

281 lines
6.3 KiB
Go

package cassandra
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
cassandracontainer "github.com/testcontainers/testcontainers-go/modules/cassandra"
)
const (
// cassandraImage is the default image used for running cassandra in tests.
cassandraImage = "cassandra:latest"
// cassandraImageEnvVar names the environment variable that, when set to a
// non-empty value, overrides cassandraImage (see newTestStore).
cassandraImageEnvVar string = "TEST_CASSANDRA_IMAGE"
// cassandraPort is the container port handed to MappedPort in newTestStore
// to look up the host-side port the test client connects to.
cassandraPort = "9042/tcp"
)
// newTestStore starts a Cassandra container through the testcontainers
// Cassandra module and returns a Storage connected to it.
//
// The image defaults to cassandraImage and is replaced by the value of the
// TEST_CASSANDRA_IMAGE environment variable when that is non-empty. The
// store is configured with the given keyspace, the "test_kv" table,
// consistency ONE and a 5 second default expiration. Container and store
// cleanup are both registered on t; any setup error fails the test
// immediately.
func newTestStore(t testing.TB, keyspace string) *Storage {
	t.Helper()

	image := os.Getenv(cassandraImageEnvVar)
	if image == "" {
		image = cassandraImage
	}

	ctx := context.Background()

	container, runErr := cassandracontainer.Run(ctx, image)
	// Register cleanup before checking the error so it is scheduled even
	// when Run fails.
	testcontainers.CleanupContainer(t, container)
	require.NoError(t, runErr)

	host, err := container.Host(ctx)
	require.NoError(t, err)

	mapped, err := container.MappedPort(ctx, cassandraPort)
	require.NoError(t, err)

	cfg := Config{
		Hosts:       []string{fmt.Sprintf("%s:%d", host, mapped.Int())},
		Keyspace:    keyspace,
		Table:       "test_kv",
		Consistency: gocql.One,
		Expiration:  5 * time.Second,
	}

	store, err := New(cfg)
	require.NoError(t, err)

	t.Cleanup(func() { store.Close() })

	return store
}
// Test_keyspace_creation tests the keyspace creation functionality
func Test_keyspace_creation(t *testing.T) {
store := newTestStore(t, "test_keyspace_creation")
require.NotNil(t, store)
}
// Test_set stores a value with no explicit expiration and confirms it can
// be read back unchanged.
func Test_set(t *testing.T) {
	store := newTestStore(t, "test_basic_ops")
	require.NotNil(t, store)

	// Write the entry.
	require.NoError(t, store.Set("test", []byte("value"), 0))

	// Read it back and compare.
	stored, err := store.Get("test")
	require.NoError(t, err)
	require.Equal(t, []byte("value"), stored)
}
// Test_get covers both outcomes of Get: an existing key returns its value,
// and a key that was never written returns an error together with a nil
// value.
func Test_get(t *testing.T) {
	store := newTestStore(t, "test_basic_ops")
	require.NotNil(t, store)

	// Seed the key that the hit case reads.
	require.NoError(t, store.Set("test", []byte("value"), 0))

	// Hit: the stored bytes come back.
	found, err := store.Get("test")
	require.NoError(t, err)
	require.Equal(t, []byte("value"), found)

	// Miss: an unknown key is reported as an error.
	missing, err := store.Get("nonexistent")
	require.Error(t, err)
	require.Nil(t, missing)
}
// Test_delete writes a key, confirms it is readable, deletes it, and then
// expects Get to return an error and a nil value for it.
func Test_delete(t *testing.T) {
	store := newTestStore(t, "test_basic_ops")
	require.NotNil(t, store)

	// Seed the key to be removed.
	require.NoError(t, store.Set("test", []byte("value"), 0))

	// The key must be present before the delete.
	before, err := store.Get("test")
	require.NoError(t, err)
	require.Equal(t, []byte("value"), before)

	// Remove it.
	require.NoError(t, store.Delete("test"))

	// The key must be gone afterwards.
	after, err := store.Get("test")
	require.Error(t, err)
	require.Nil(t, after)
}
// Test_expirable_keys writes a key with a one second expiration, checks it
// is readable right away, and then polls every 100ms for up to three seconds
// until Get reports it as missing.
func Test_expirable_keys(t *testing.T) {
	store := newTestStore(t, "test_expirable")
	require.NotNil(t, store)

	// Store the key with a one second lifetime.
	require.NoError(t, store.Set("test", []byte("value"), time.Second))

	// It must still be there immediately after the write.
	fresh, err := store.Get("test")
	require.NoError(t, err)
	require.Equal(t, []byte("value"), fresh)

	// expired reports whether Get now fails with a nil value.
	expired := func() bool {
		got, getErr := store.Get("test")
		return getErr != nil && got == nil
	}
	require.Eventually(t, expired, 3*time.Second, 100*time.Millisecond, "Key should expire within 3 seconds")
}
// Test_concurrent_access runs ten goroutines against one store. Each
// goroutine writes its own key and reads it back, and the test waits for all
// of them before returning.
//
// Failures inside the goroutines are reported with t.Errorf rather than
// require.*: require calls t.FailNow, which the testing package only allows
// on the goroutine running the test function. t.Errorf may be called from
// any goroutine.
func Test_concurrent_access(t *testing.T) {
	store := newTestStore(t, "test_concurrent")
	require.NotNil(t, store)

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			key := fmt.Sprintf("key%d", i)
			value := []byte(fmt.Sprintf("value%d", i))

			if err := store.Set(key, value, 0); err != nil {
				t.Errorf("Set(%q) failed: %v", key, err)
				return
			}

			val, err := store.Get(key)
			if err != nil {
				t.Errorf("Get(%q) failed: %v", key, err)
				return
			}

			// Compare as strings to check byte-wise equality without
			// needing an extra import.
			if string(val) != string(value) {
				t.Errorf("Get(%q) = %q, want %q", key, val, value)
			}
		}(i)
	}

	// All goroutines must finish before the test function returns, since
	// t must not be used after the test has completed.
	wg.Wait()
}
// Test_reset writes two entries, calls Reset, and then expects Get to
// return an error and a nil value for each of them.
func Test_reset(t *testing.T) {
	store := newTestStore(t, "test_reset")
	require.NotNil(t, store)

	entries := []struct {
		key   string
		value string
	}{
		{"test1", "value1"},
		{"test2", "value2"},
	}

	// Populate the store.
	for _, e := range entries {
		require.NoError(t, store.Set(e.key, []byte(e.value), 0))
	}

	// Wipe everything.
	require.NoError(t, store.Reset())

	// None of the entries may be readable any more.
	for _, e := range entries {
		got, err := store.Get(e.key)
		require.Error(t, err)
		require.Nil(t, got)
	}
}
// Test_valid_identifiers runs one subtest per key and expects Set to accept
// each of them. Every subtest is named after the key it stores.
func Test_valid_identifiers(t *testing.T) {
	store := newTestStore(t, "test_validation")
	require.NotNil(t, store)

	// Each key doubles as the subtest name.
	keys := []string{
		"test",
		"test123",
		"test_123",
		"TEST",
		"Test123",
		"test-table",
		"test.table",
	}

	for _, key := range keys {
		t.Run(key, func(t *testing.T) {
			require.NoError(t, store.Set(key, []byte("value"), 0))
		})
	}
}
// Test_invalid_identifiers tests invalid identifier cases
func Test_invalid_identifiers(t *testing.T) {
invalidCases := []struct {
name string
key string
}{
{"empty", ""},
{"space", "test key"},
{"quote", `test"key`},
{"semicolon", "test;key"},
{"sql_injection", "test' OR '1'='1"},
{"unicode", "test\u2028key"},
}
for _, tc := range invalidCases {
t.Run(fmt.Sprintf("invalid_%s", tc.name), func(t *testing.T) {
_, err := validateIdentifier(tc.key, "test")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid test name")
})
}
}
// Benchmark_Cassandra_Set measures repeated Set calls that overwrite the
// same key. Store setup happens before the timer is reset, so it is not
// part of the measurement.
func Benchmark_Cassandra_Set(b *testing.B) {
	store := newTestStore(b, "test_concurrent")

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		require.NoError(b, store.Set("john", []byte("doe"), 0))
	}
}
// Benchmark_Cassandra_Get measures repeated Get calls for a key that
// exists.
//
// The key is written once before the timer is reset. newTestStore returns a
// store backed by a fresh container, and Get returns an error for a key that
// was never written, so without this seed the first iteration would fail.
func Benchmark_Cassandra_Get(b *testing.B) {
	store := newTestStore(b, "test_concurrent")

	// Seed the key that the timed loop reads.
	err := store.Set("john", []byte("doe"), 0)
	require.NoError(b, err)

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		_, err := store.Get("john")
		require.NoError(b, err)
	}
}
// Benchmark_Cassandra_Set_And_Delete measures a Set immediately followed by
// a Delete of the same key. Store setup happens before the timer is reset.
//
// The Set error is checked as well as the Delete error, so a failing write
// stops the benchmark instead of being silently timed.
func Benchmark_Cassandra_Set_And_Delete(b *testing.B) {
	store := newTestStore(b, "test_concurrent")

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		err := store.Set("john", []byte("doe"), 0)
		require.NoError(b, err)

		err = store.Delete("john")
		require.NoError(b, err)
	}
}