package cassandra import ( "fmt" "strings" "time" "unicode" "github.com/gocql/gocql" "github.com/scylladb/gocqlx/v2" "github.com/scylladb/gocqlx/v2/qb" ) var ( // ErrNotFound is returned when the key does not exist ErrNotFound = fmt.Errorf("key not found") // ErrKeyExpired is returned when the key has expired ErrKeyExpired = fmt.Errorf("key expired") ) // Storage represents a Cassandra storage implementation type Storage struct { cluster *gocql.ClusterConfig session *gocql.Session sx gocqlx.Session keyspace string table string ttl int } // validateIdentifier checks if an identifier is valid func validateIdentifier(name, identifierType string) (string, error) { if name == "" { return "", fmt.Errorf("invalid %s name: cannot be empty", identifierType) } // Check for invalid characters if strings.ContainsAny(name, " \t\n\r\f\v") { return "", fmt.Errorf("invalid %s name: cannot contain whitespace", identifierType) } // Check for SQL injection attempts and special characters if strings.ContainsAny(name, ";'\"-.") { return "", fmt.Errorf("invalid %s name: cannot contain special characters", identifierType) } // Check for unicode characters for _, r := range name { if r > unicode.MaxASCII { return "", fmt.Errorf("invalid %s name: cannot contain unicode characters", identifierType) } } // Only allow alphanumeric characters and underscores for _, r := range name { if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { return "", fmt.Errorf("invalid %s name: can only contain letters, numbers, and underscores", identifierType) } } return name, nil } // New creates a new Cassandra storage instance func New(cnfg Config) (*Storage, error) { // Default config cfg := configDefault(cnfg) // Validate and escape identifiers keyspace, err := validateIdentifier(cfg.Keyspace, "keyspace") if err != nil { return nil, err } table, err := validateIdentifier(cfg.Table, "table") if err != nil { return nil, err } // Create cluster config cluster := gocql.NewCluster(cfg.Hosts...) // Safe check for SSL options if cfg.SslOpts != nil { cluster.SslOpts = cfg.SslOpts } cluster.Consistency = cfg.Consistency cluster.ConnectTimeout = cfg.ConnectTimeout cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: cfg.MaxRetries} // Convert expiration to seconds for TTL ttl := 0 if cfg.Expiration > 0 { ttl = int(cfg.Expiration.Seconds()) } else if cfg.Expiration < 0 { // Expiration < 0 means indefinite storage cfg.Expiration = 0 } // Create storage instance storage := &Storage{ cluster: cluster, keyspace: keyspace, table: table, ttl: ttl, } // Initialize keyspace if err := storage.createOrVerifyKeySpace(cfg.Reset); err != nil { return nil, fmt.Errorf("cassandra storage init: %w", err) } return storage, nil } // createOrVerifyKeySpace ensures the keyspace and table exist with proper keyspace func (s *Storage) createOrVerifyKeySpace(reset bool) error { // Clone the original cluster config and set system keyspace systemCluster := *s.cluster systemCluster.Keyspace = "system" systemCluster.PoolConfig = gocql.PoolConfig{ HostSelectionPolicy: gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()), } // Connect to the system keyspace systemSession, err := systemCluster.CreateSession() if err != nil { return fmt.Errorf("failed to connect to system keyspace: %w", err) } defer systemSession.Close() // Create keyspace if not exists err = s.ensureKeyspace(systemSession) if err != nil { return fmt.Errorf("failed to ensure keyspace exists: %w", err) } // Now connect to our keyspace s.cluster.Keyspace = s.keyspace session, err := s.cluster.CreateSession() if err != nil { return fmt.Errorf("failed to connect to keyspace %s: %w", s.keyspace, err) } s.session = session s.sx = gocqlx.NewSession(session) // Drop tables if reset is requested if reset { if err := s.dropTables(); err != nil { return err } } // Create data table if necessary if err := s.createDataTable(); err != nil { return err } return nil } // ensureKeyspace creates the keyspace if it doesn't exist func (s *Storage) ensureKeyspace(systemSession *gocql.Session) error { // Check if keyspace exists var count int if err := systemSession.Query( "SELECT COUNT(*) FROM system_schema.keyspaces WHERE keyspace_name = ?", s.keyspace, ).Scan(&count); err != nil { return err } // Create keyspace if it doesn't exist if count == 0 { query := fmt.Sprintf( "CREATE KEYSPACE %s WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}", s.keyspace, ) if err := systemSession.Query(query).Exec(); err != nil { return err } } return nil } // createDataTable creates the data table for key-value storage func (s *Storage) createDataTable() error { // Create table with proper escaping query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s.%s ( key text PRIMARY KEY, value blob, expires_at timestamp ) `, s.keyspace, s.table) // Use gocqlx session return s.sx.Query(query, []string{}).ExecRelease() } // dropTables drops existing tables for reset func (s *Storage) dropTables() error { // Drop data table with proper escaping query := fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", s.keyspace, s.table) if err := s.sx.Query(query, []string{}).ExecRelease(); err != nil { return err } // Drop schema_info table with proper escaping query = fmt.Sprintf("DROP TABLE IF EXISTS %s.schema_info", s.keyspace) return s.sx.Query(query, []string{}).ExecRelease() } // queryResult holds the result of a SELECT query type queryResult struct { Value []byte `db:"value"` ExpiresAt time.Time `db:"expires_at"` } // Set stores a key-value pair with optional expiration func (s *Storage) Set(key string, value []byte, exp time.Duration) error { // Validate key if _, err := validateIdentifier(key, "key"); err != nil { return err } // Calculate expiration time var expiresAt *time.Time var ttl int if exp > 0 { // Specific expiration provided ttl = int(exp.Seconds()) t := time.Now().Add(exp) expiresAt = &t } else if exp == 0 && s.ttl > 0 { // Use default TTL from config ttl = s.ttl t := time.Now().Add(time.Duration(s.ttl) * time.Second) expiresAt = &t } // If exp == 0 and s.ttl == 0, no TTL will be set (live forever) // Use query builder for insert stmt, names := qb.Insert(fmt.Sprintf("%s.%s", s.keyspace, s.table)). Columns("key", "value", "expires_at"). ToCql() if ttl > 0 { stmt += fmt.Sprintf(" USING TTL %d", ttl) } // Use gocqlx session return s.sx.Query(stmt, names).BindMap(map[string]interface{}{ "key": key, "value": value, "expires_at": expiresAt, }).ExecRelease() } // Get retrieves a value by key func (s *Storage) Get(key string) ([]byte, error) { // Use query builder for select stmt, names := qb.Select(fmt.Sprintf("%s.%s", s.keyspace, s.table)). Columns("value", "expires_at"). Where(qb.Eq("key")). ToCql() var result queryResult // Use gocqlx session if err := s.sx.Query(stmt, names).BindMap(map[string]interface{}{ "key": key, }).GetRelease(&result); err != nil { if err == gocql.ErrNotFound { return nil, ErrNotFound } return nil, err } // Check if the key has expired if !result.ExpiresAt.IsZero() && time.Now().After(result.ExpiresAt) { // Delete the expired key if err := s.Delete(key); err != nil { return nil, err } return nil, ErrKeyExpired } return result.Value, nil } // Delete removes a key from storage func (s *Storage) Delete(key string) error { // Use query builder for delete stmt, names := qb.Delete(fmt.Sprintf("%s.%s", s.keyspace, s.table)). Where(qb.Eq("key")). ToCql() // Use gocqlx session return s.sx.Query(stmt, names).BindMap(map[string]interface{}{ "key": key, }).ExecRelease() } // Reset clears all keys from storage func (s *Storage) Reset() error { // Use direct TRUNCATE query with proper escaping query := fmt.Sprintf("TRUNCATE TABLE %s.%s", s.keyspace, s.table) return s.sx.Query(query, []string{}).ExecRelease() } // Conn returns the underlying gocql session. func (s *Storage) Conn() *gocql.Session { return s.session } // Close closes the storage connection. // This method is not thread-safe and should not be called concurrently with other methods. func (s *Storage) Close() { if s.session != nil { s.session.Close() } }