Files
storage/cassandra/cassandra.go
2025-04-15 12:03:18 +05:30

228 lines
5.6 KiB
Go

package cassandra
import (
"errors"
"fmt"
"log"
"time"
"github.com/gocql/gocql"
)
// Storage is a key-value store backed by a single Cassandra table and
// accessed through a gocql session.
type Storage struct {
	// cluster is the gocql cluster configuration built in New from the
	// configured hosts and consistency level.
	cluster *gocql.ClusterConfig
	// session is the connection bound to keyspace. It is assigned by
	// createOrVerifyKeySpace and closed by Close.
	session *gocql.Session
	// keyspace is the name of the keyspace that contains table.
	keyspace string
	// table is the name of the table holding the key-value rows.
	table string
	// ttl is the default time-to-live in seconds that Set applies when it is
	// called with a zero expiration; 0 means there is no default TTL.
	ttl int
}
// New builds a Storage from cnfg (after filling in defaults) and prepares the
// keyspace and data table it needs. It logs and panics if that preparation
// fails.
func New(cnfg Config) *Storage {
	cfg := configDefault(cnfg)

	// The keyspace is deliberately left unset on the cluster config at this
	// point: it may not exist yet and is assigned once it has been ensured.
	clusterCfg := gocql.NewCluster(cfg.Hosts...)
	clusterCfg.Consistency = cfg.Consistency

	// The default TTL is kept in whole seconds; a non-positive expiration
	// leaves it at zero, i.e. no default TTL.
	var defaultTTL int
	if cfg.Expiration > 0 {
		defaultTTL = int(cfg.Expiration.Seconds())
	}

	store := &Storage{
		cluster:  clusterCfg,
		keyspace: cfg.Keyspace,
		table:    cfg.Table,
		ttl:      defaultTTL,
	}

	err := store.createOrVerifyKeySpace(cfg.Reset)
	if err != nil {
		log.Printf("Failed to initialize keyspace: %v", err)
		panic(err)
	}
	return store
}
// createOrVerifyKeySpace ensures the keyspace and the data table exist and
// leaves s.session connected to the keyspace.
//
// It first opens a short-lived session without a keyspace to create the
// keyspace if needed, then opens the long-lived session bound to it. When
// reset is true the existing tables are dropped before the data table is
// created again.
//
// If any step after opening the keyspace session fails, that session is
// closed and s.session is cleared, so a failed setup does not leak
// connections. All returned errors wrap the underlying cause.
func (s *Storage) createOrVerifyKeySpace(reset bool) error {
	// The keyspace may not exist yet, so the first connection must not name it.
	systemCluster := gocql.NewCluster(s.cluster.Hosts...)
	systemCluster.Consistency = s.cluster.Consistency
	systemCluster.Timeout = s.cluster.Timeout

	systemSession, err := systemCluster.CreateSession()
	if err != nil {
		return fmt.Errorf("failed to connect to system keyspace: %w", err)
	}
	defer systemSession.Close()

	if err := s.ensureKeyspace(systemSession); err != nil {
		return fmt.Errorf("failed to ensure keyspace exists: %w", err)
	}

	// Now that the keyspace exists, bind the long-lived session to it.
	s.cluster.Keyspace = s.keyspace
	session, err := s.cluster.CreateSession()
	if err != nil {
		return fmt.Errorf("failed to connect to keyspace %s: %w", s.keyspace, err)
	}
	s.session = session

	// dropTables and createDataTable run through s.session, so it has to be
	// assigned before they are called; undo that if either of them fails.
	ready := false
	defer func() {
		if !ready {
			session.Close()
			s.session = nil
		}
	}()

	if reset {
		if err := s.dropTables(); err != nil {
			return fmt.Errorf("failed to drop tables: %w", err)
		}
	}

	if err := s.createDataTable(); err != nil {
		return fmt.Errorf("failed to create data table: %w", err)
	}

	ready = true
	return nil
}
// ensureKeyspace creates the storage keyspace through systemSession when it
// does not exist yet.
//
// The lookup in system_schema.keyspaces is only a fast path that also decides
// whether to log the creation. The CREATE statement itself uses IF NOT EXISTS
// so that it stays idempotent when the lookup is wrong: another instance may
// create the keyspace between the lookup and the CREATE, and a name that
// differs from the stored keyspace name only by letter case would not be
// matched by the lookup.
//
// It returns the error of the lookup or of the CREATE statement, unwrapped.
func (s *Storage) ensureKeyspace(systemSession *gocql.Session) error {
	var count int
	if err := systemSession.Query(
		"SELECT COUNT(*) FROM system_schema.keyspaces WHERE keyspace_name = ?",
		s.keyspace,
	).Scan(&count); err != nil {
		return err
	}
	if count > 0 {
		return nil
	}

	query := fmt.Sprintf(
		"CREATE KEYSPACE IF NOT EXISTS %s WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}",
		s.keyspace,
	)
	if err := systemSession.Query(query).Exec(); err != nil {
		return err
	}
	log.Printf("Created keyspace: %s", s.keyspace)
	return nil
}
// createDataTable issues CREATE TABLE IF NOT EXISTS for the key-value table
// (text key as primary key, blob value, expires_at timestamp) in the
// configured keyspace and returns the result of executing it.
func (s *Storage) createDataTable() error {
	const ddl = `
CREATE TABLE IF NOT EXISTS %s.%s (
key text PRIMARY KEY,
value blob,
expires_at timestamp
)
`
	stmt := fmt.Sprintf(ddl, s.keyspace, s.table)
	return s.session.Query(stmt).Exec()
}
// dropTables removes the data table and the schema_info table from the
// keyspace, in that order, stopping at the first statement that fails.
func (s *Storage) dropTables() error {
	stmts := []string{
		fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", s.keyspace, s.table),
		fmt.Sprintf("DROP TABLE IF EXISTS %s.schema_info", s.keyspace),
	}
	for _, stmt := range stmts {
		if err := s.session.Query(stmt).Exec(); err != nil {
			return err
		}
	}
	return nil
}
// Set stores value under key.
//
// The expiration is chosen from exp:
//   - exp > 0: the entry expires after exp.
//   - exp == 0: the default TTL from the configuration is used, if one is set;
//     otherwise the entry does not expire.
//   - exp < 0: the entry does not expire.
//
// When an expiration applies, the exact deadline is written to expires_at and
// the row is also inserted with a TTL in whole seconds. The TTL is rounded up
// rather than truncated, so the row is never removed before the requested
// deadline and a sub-second expiration still gets a TTL.
//
// It returns the error from executing the INSERT, if any.
func (s *Storage) Set(key string, value []byte, exp time.Duration) error {
	var (
		expiresAt *time.Time // stays nil when the entry does not expire
		ttl       int        // seconds; 0 means insert without a TTL
	)

	switch {
	case exp > 0:
		// Round up to whole seconds without risking overflow on huge durations.
		ttl = int(exp / time.Second)
		if exp%time.Second != 0 {
			ttl++
		}
		t := time.Now().Add(exp)
		expiresAt = &t
	case exp == 0 && s.ttl > 0:
		ttl = s.ttl
		t := time.Now().Add(time.Duration(s.ttl) * time.Second)
		expiresAt = &t
	}

	query := fmt.Sprintf("INSERT INTO %s.%s (key, value, expires_at) VALUES (?, ?, ?)",
		s.keyspace, s.table)
	if ttl > 0 {
		query += fmt.Sprintf(" USING TTL %d", ttl)
	}
	return s.session.Query(query, key, value, expiresAt).Exec()
}
// Get looks up key and returns its stored value. A key that is not found, or
// whose expires_at deadline has already passed, yields (nil, nil); any other
// query failure is returned as the error.
func (s *Storage) Get(key string) ([]byte, error) {
	var (
		data     []byte
		deadline time.Time
	)

	stmt := fmt.Sprintf("SELECT value, expires_at FROM %s.%s WHERE key = ?", s.keyspace, s.table)
	err := s.session.Query(stmt, key).Scan(&data, &deadline)
	switch {
	case errors.Is(err, gocql.ErrNotFound):
		return nil, nil
	case err != nil:
		return nil, err
	}

	// A zero deadline is treated as "no expiry"; otherwise the entry is live
	// until the deadline has passed.
	live := deadline.IsZero() || !deadline.Before(time.Now())
	if live {
		return data, nil
	}

	// The row is past its deadline but still present, so remove it here. A
	// failed delete is only logged; the caller still sees the key as absent.
	if delErr := s.Delete(key); delErr != nil {
		log.Printf("Failed to delete expired key %s: %v", key, delErr)
	}
	return nil, nil
}
// Delete issues a DELETE for the row stored under key in the data table and
// returns the result of executing it.
func (s *Storage) Delete(key string) error {
	stmt := fmt.Sprintf("DELETE FROM %s.%s WHERE key = ?", s.keyspace, s.table)
	q := s.session.Query(stmt, key)
	return q.Exec()
}
// Reset truncates the data table, removing every stored key, and returns the
// result of executing the TRUNCATE statement.
func (s *Storage) Reset() error {
	return s.session.Query(
		fmt.Sprintf("TRUNCATE TABLE %s.%s", s.keyspace, s.table),
	).Exec()
}
// Close shuts down the storage's session. It does nothing when no session is
// set.
func (s *Storage) Close() {
	if s.session == nil {
		return
	}
	s.session.Close()
}