mirror of
https://github.com/gofiber/storage.git
synced 2025-10-05 00:33:03 +08:00
417 lines
12 KiB
Go
417 lines
12 KiB
Go
package cassandra
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"time"
|
|
|
|
"github.com/gocql/gocql"
|
|
)
|
|
|
|
// Storage represents a Cassandra storage implementation
|
|
type Storage struct {
|
|
cluster *gocql.ClusterConfig
|
|
session *gocql.Session
|
|
keyspace string
|
|
table string
|
|
ttl int
|
|
}
|
|
|
|
// SchemaInfo carries schema metadata: a version number, a free-form
// description, and the creation and last-update timestamps.
type SchemaInfo struct {
	// Version is the schema version number.
	Version int
	// Description is a human-readable note about the schema.
	Description string
	// CreatedAt is when the schema record was created.
	CreatedAt time.Time
	// UpdatedAt is when the schema record was last updated.
	UpdatedAt time.Time
}
|
|
|
|
// New creates a new Cassandra storage instance
|
|
func New(cfg Config) *Storage {
|
|
// Create cluster config
|
|
cluster := gocql.NewCluster(cfg.Hosts...)
|
|
|
|
// Don't set keyspace initially - we need to create it first
|
|
// We'll connect to system keyspace first
|
|
|
|
// Convert expiration to seconds for TTL
|
|
ttl := 0
|
|
if cfg.Expiration > 0 {
|
|
ttl = int(cfg.Expiration.Seconds())
|
|
}
|
|
|
|
// Create storage instance
|
|
storage := &Storage{
|
|
cluster: cluster,
|
|
keyspace: cfg.Keyspace,
|
|
table: cfg.Table,
|
|
ttl: ttl,
|
|
}
|
|
|
|
// Initialize keyspace
|
|
if err := storage.createOrVerifyKeySpace(cfg.Reset); err != nil {
|
|
log.Printf("Failed to initialize keyspace: %v", err)
|
|
panic(err)
|
|
}
|
|
|
|
return storage
|
|
}
|
|
|
|
// createOrVerifyKeySpace ensures the keyspace and table exist with proper keyspace
|
|
func (s *Storage) createOrVerifyKeySpace(reset bool) error {
|
|
// Connect to system keyspace first to create our keyspace if needed
|
|
systemCluster := gocql.NewCluster(s.cluster.Hosts...)
|
|
systemCluster.Consistency = s.cluster.Consistency
|
|
systemCluster.Timeout = s.cluster.Timeout
|
|
|
|
// Connect to the system keyspace
|
|
systemSession, err := systemCluster.CreateSession()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to system keyspace: %w", err)
|
|
}
|
|
defer systemSession.Close()
|
|
|
|
// Create keyspace if not exists
|
|
err = s.ensureKeyspace(systemSession)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to ensure keyspace exists: %w", err)
|
|
}
|
|
|
|
// Now connect to our keyspace
|
|
s.cluster.Keyspace = s.keyspace
|
|
session, err := s.cluster.CreateSession()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to keyspace %s: %w", s.keyspace, err)
|
|
}
|
|
s.session = session
|
|
|
|
// Drop tables if reset is requested
|
|
if reset {
|
|
if err := s.dropTables(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Create data table if necessary
|
|
if err := s.createDataTable(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ensureKeyspace creates the keyspace if it doesn't exist
|
|
func (s *Storage) ensureKeyspace(systemSession *gocql.Session) error {
|
|
// Check if keyspace exists
|
|
var count int
|
|
if err := systemSession.Query(
|
|
"SELECT COUNT(*) FROM system_schema.keyspaces WHERE keyspace_name = ?",
|
|
s.keyspace,
|
|
).Scan(&count); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Create keyspace if it doesn't exist
|
|
if count == 0 {
|
|
query := fmt.Sprintf(
|
|
"CREATE KEYSPACE %s WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}",
|
|
s.keyspace,
|
|
)
|
|
if err := systemSession.Query(query).Exec(); err != nil {
|
|
return err
|
|
}
|
|
log.Printf("Created keyspace: %s", s.keyspace)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createDataTable creates the data table for key-value storage
|
|
func (s *Storage) createDataTable() error {
|
|
query := fmt.Sprintf(`
|
|
CREATE TABLE IF NOT EXISTS %s.%s (
|
|
key text PRIMARY KEY,
|
|
value blob,
|
|
created_at timestamp,
|
|
expires_at timestamp
|
|
)
|
|
`, s.keyspace, s.table)
|
|
|
|
return s.session.Query(query).Exec()
|
|
}
|
|
|
|
// dropTables drops existing tables for reset
|
|
func (s *Storage) dropTables() error {
|
|
// Drop data table
|
|
query := fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", s.keyspace, s.table)
|
|
if err := s.session.Query(query).Exec(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Drop schema_info table
|
|
query = fmt.Sprintf("DROP TABLE IF EXISTS %s.schema_info", s.keyspace)
|
|
return s.session.Query(query).Exec()
|
|
}
|
|
|
|
// Set stores a key-value pair with optional expiration
|
|
func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
|
|
// Calculate expiration time
|
|
var expiresAt *time.Time
|
|
var ttl int = -1 // Default to no TTL
|
|
|
|
if exp > 0 {
|
|
// Specific expiration provided
|
|
ttl = int(exp.Seconds())
|
|
t := time.Now().Add(exp)
|
|
expiresAt = &t
|
|
} else if exp == 0 && s.ttl > 0 {
|
|
// Use default TTL from config
|
|
ttl = s.ttl
|
|
t := time.Now().Add(time.Duration(s.ttl) * time.Second)
|
|
expiresAt = &t
|
|
}
|
|
// If exp < 0, we'll use no TTL (indefinite storage)
|
|
|
|
// Insert with TTL if specified
|
|
var query string
|
|
if ttl > 0 {
|
|
query = fmt.Sprintf("INSERT INTO %s.%s (key, value, created_at, expires_at) VALUES (?, ?, ?, ?) USING TTL %d",
|
|
s.keyspace, s.table, ttl)
|
|
} else {
|
|
query = fmt.Sprintf("INSERT INTO %s.%s (key, value, created_at, expires_at) VALUES (?, ?, ?, ?)",
|
|
s.keyspace, s.table)
|
|
}
|
|
|
|
return s.session.Query(query, key, value, time.Now(), expiresAt).Exec()
|
|
}
|
|
|
|
// Get retrieves a value by key
|
|
func (s *Storage) Get(key string) ([]byte, error) {
|
|
var value []byte
|
|
var expiresAt time.Time
|
|
|
|
query := fmt.Sprintf("SELECT value, expires_at FROM %s.%s WHERE key = ?", s.keyspace, s.table)
|
|
if err := s.session.Query(query, key).Scan(&value, &expiresAt); err != nil {
|
|
if errors.Is(err, gocql.ErrNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Check if expired (as a backup in case TTL didn't work)
|
|
if !expiresAt.IsZero() && expiresAt.Before(time.Now()) {
|
|
// Expired but not yet removed by TTL
|
|
err := s.Delete(key)
|
|
if err != nil {
|
|
log.Printf("Failed to delete expired key %s: %v", key, err)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
return value, nil
|
|
}
|
|
|
|
// Delete removes a key from storage
|
|
func (s *Storage) Delete(key string) error {
|
|
query := fmt.Sprintf("DELETE FROM %s.%s WHERE key = ?", s.keyspace, s.table)
|
|
return s.session.Query(query, key).Exec()
|
|
}
|
|
|
|
// Reset clears all keys from storage
|
|
func (s *Storage) Reset() error {
|
|
query := fmt.Sprintf("TRUNCATE TABLE %s.%s", s.keyspace, s.table)
|
|
return s.session.Query(query).Exec()
|
|
}
|
|
|
|
// Close closes the storage connection
|
|
func (s *Storage) Close() {
|
|
if s.session != nil {
|
|
s.session.Close()
|
|
}
|
|
}
|
|
|
|
// Test functions
|
|
|
|
// // setupCassandraContainer creates a Cassandra container using the official module
|
|
// func setupCassandraContainer(ctx context.Context) (*cassandracontainer.CassandraContainer, string, error) {
|
|
// cassandraContainer, err := cassandracontainer.RunContainer(ctx,
|
|
// testcontainers.WithImage("cassandra:4.1"),
|
|
// cassandracontainer.WithInitialWaitTimeout(2*time.Minute),
|
|
// )
|
|
// if err != nil {
|
|
// return nil, "", err
|
|
// }
|
|
|
|
// // Get connection parameters
|
|
// host, err := cassandraContainer.Host(ctx)
|
|
// if err != nil {
|
|
// return nil, "", err
|
|
// }
|
|
|
|
// mappedPort, err := cassandraContainer.MappedPort(ctx, "9042/tcp")
|
|
// if err != nil {
|
|
// return nil, "", err
|
|
// }
|
|
|
|
// connectionURL := host + ":" + mappedPort.Port()
|
|
// return cassandraContainer, connectionURL, nil
|
|
// }
|
|
|
|
// func TestSchemaManagement(t *testing.T) {
|
|
// ctx := context.Background()
|
|
|
|
// // Start Cassandra container
|
|
// cassandraContainer, connectionURL, err := setupCassandraContainer(ctx)
|
|
// if err != nil {
|
|
// t.Fatalf("Failed to start Cassandra container: %v", err)
|
|
// }
|
|
// defer func() {
|
|
// if err := cassandraContainer.Terminate(ctx); err != nil {
|
|
// t.Logf("Failed to terminate container: %v", err)
|
|
// }
|
|
// }()
|
|
|
|
// // 1. Test keyspace creation
|
|
// store := New(Config{
|
|
// Hosts: []string{connectionURL},
|
|
// Keyspace: "test_keyspace_creation",
|
|
// Table: "test_table",
|
|
// SchemaVersion: 1,
|
|
// SchemaDescription: "Initial Schema",
|
|
// })
|
|
|
|
// // Verify keyspace was created
|
|
// systemCluster := gocql.NewCluster(connectionURL)
|
|
// systemSession, err := systemCluster.CreateSession()
|
|
// if err != nil {
|
|
// t.Fatalf("Failed to connect to system keyspace: %v", err)
|
|
// }
|
|
|
|
// var count int
|
|
// err = systemSession.Query(
|
|
// "SELECT COUNT(*) FROM system_schema.keyspaces WHERE keyspace_name = ?",
|
|
// "test_keyspace_creation",
|
|
// ).Scan(&count)
|
|
// assert.NoError(t, err)
|
|
// assert.Equal(t, 1, count, "Keyspace should have been created")
|
|
// systemSession.Close()
|
|
|
|
// // 2. Test table creation
|
|
// // Connect to the keyspace and check if tables exist
|
|
// cluster := gocql.NewCluster(connectionURL)
|
|
// cluster.Keyspace = "test_keyspace_creation"
|
|
// session, err := cluster.CreateSession()
|
|
// if err != nil {
|
|
// t.Fatalf("Failed to connect to keyspace: %v", err)
|
|
// }
|
|
|
|
// // Check for data table
|
|
// err = session.Query(
|
|
// "SELECT COUNT(*) FROM system_schema.tables WHERE keyspace_name = ? AND table_name = ?",
|
|
// "test_keyspace_creation", "test_table",
|
|
// ).Scan(&count)
|
|
// assert.NoError(t, err)
|
|
// assert.Equal(t, 1, count, "Data table should have been created")
|
|
|
|
// // Check for schema_info table
|
|
// err = session.Query(
|
|
// "SELECT COUNT(*) FROM system_schema.tables WHERE keyspace_name = ? AND table_name = ?",
|
|
// "test_keyspace_creation", "schema_info",
|
|
// ).Scan(&count)
|
|
// assert.NoError(t, err)
|
|
// assert.Equal(t, 1, count, "Schema info table should have been created")
|
|
|
|
// session.Close()
|
|
// store.Close()
|
|
|
|
// // 3. Test schema update
|
|
// // Create initial schema
|
|
// storeV1 := New(Config{
|
|
// Hosts: []string{connectionURL},
|
|
// Keyspace: "test_schema_update",
|
|
// Table: "test_table",
|
|
// SchemaVersion: 1,
|
|
// SchemaDescription: "Schema v1",
|
|
// })
|
|
|
|
// // Verify initial schema
|
|
// schemaInfo, err := storeV1.GetSchemaInfo()
|
|
// assert.NoError(t, err)
|
|
// assert.Equal(t, 1, schemaInfo.Version)
|
|
// assert.Equal(t, "Schema v1", schemaInfo.Description)
|
|
// createdAt := schemaInfo.CreatedAt
|
|
|
|
// // Close and create with higher version
|
|
// storeV1.Close()
|
|
|
|
// // Create updated schema
|
|
// storeV2 := New(Config{
|
|
// Hosts: []string{connectionURL},
|
|
// Keyspace: "test_schema_update",
|
|
// Table: "test_table",
|
|
// SchemaVersion: 2,
|
|
// SchemaDescription: "Schema v2",
|
|
// })
|
|
|
|
// // Verify schema was updated
|
|
// updatedSchema, err := storeV2.GetSchemaInfo()
|
|
// assert.NoError(t, err)
|
|
// assert.Equal(t, 2, updatedSchema.Version)
|
|
// assert.Equal(t, "Schema v2", updatedSchema.Description)
|
|
// assert.Equal(t, createdAt.Format(time.RFC3339), updatedSchema.CreatedAt.Format(time.RFC3339))
|
|
// assert.True(t, updatedSchema.UpdatedAt.After(createdAt))
|
|
|
|
// storeV2.Close()
|
|
|
|
// // 4. Test forced schema update with same version
|
|
// storeForce := New(Config{
|
|
// Hosts: []string{connectionURL},
|
|
// Keyspace: "test_schema_update",
|
|
// Table: "test_table",
|
|
// SchemaVersion: 2, // Same version
|
|
// SchemaDescription: "Schema v2 forced",
|
|
// ForceSchemaUpdate: true,
|
|
// })
|
|
|
|
// // Verify schema was updated despite same version
|
|
// forcedSchema, err := storeForce.GetSchemaInfo()
|
|
// assert.NoError(t, err)
|
|
// assert.Equal(t, 2, forcedSchema.Version)
|
|
// assert.Equal(t, "Schema v2 forced", forcedSchema.Description)
|
|
// assert.True(t, forcedSchema.UpdatedAt.After(updatedSchema.UpdatedAt))
|
|
|
|
// storeForce.Close()
|
|
|
|
// // 5. Test reset functionality
|
|
// resetStore := New(Config{
|
|
// Hosts: []string{connectionURL},
|
|
// Keyspace: "test_schema_reset",
|
|
// Table: "test_table",
|
|
// SchemaVersion: 1,
|
|
// SchemaDescription: "Initial Schema",
|
|
// })
|
|
|
|
// // Add some data
|
|
// err = resetStore.Set("key1", []byte("value1"), 0)
|
|
// assert.NoError(t, err)
|
|
|
|
// // Close and reopen with reset
|
|
// resetStore.Close()
|
|
|
|
// resetStore = New(Config{
|
|
// Hosts: []string{connectionURL},
|
|
// Keyspace: "test_schema_reset",
|
|
// Table: "test_table",
|
|
// SchemaVersion: 1,
|
|
// SchemaDescription: "Reset Schema",
|
|
// Reset: true,
|
|
// })
|
|
|
|
// // Check that data is gone
|
|
// val, err := resetStore.Get("key1")
|
|
// assert.NoError(t, err)
|
|
// assert.Nil(t, val, "Key should be gone after reset")
|
|
|
|
// resetStore.Close()
|
|
// }
|