package redis import ( "context" "fmt" "time" "github.com/redis/go-redis/v9" ) // Storage interface that is implemented by storage providers type Storage struct { db redis.UniversalClient } // NewFromConnection creates a new instance of Storage using the provided Redis universal client. func NewFromConnection(conn redis.UniversalClient) *Storage { return &Storage{ db: conn, } } // New creates a new Redis storage instance. func New(config ...Config) *Storage { // Set default config cfg := configDefault(config...) // Create new redis universal client var db redis.UniversalClient // Parse the URL and update config values accordingly if cfg.URL != "" { options, err := redis.ParseURL(cfg.URL) if err != nil { panic(err) } // Update the config values with the parsed URL values cfg.Username = options.Username cfg.Password = options.Password cfg.Database = options.DB cfg.Addrs = []string{options.Addr} // If cfg.TLSConfig is not provided, and options returns one, use it. if cfg.TLSConfig == nil && options.TLSConfig != nil { cfg.TLSConfig = options.TLSConfig } } else if len(cfg.Addrs) == 0 { // Fallback to Host and Port values if Addrs is empty cfg.Addrs = []string{fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)} } // Create Universal Client db = redis.NewUniversalClient(&redis.UniversalOptions{ Addrs: cfg.Addrs, MasterName: cfg.MasterName, ClientName: cfg.ClientName, SentinelUsername: cfg.SentinelUsername, SentinelPassword: cfg.SentinelPassword, DB: cfg.Database, Username: cfg.Username, Password: cfg.Password, TLSConfig: cfg.TLSConfig, PoolSize: cfg.PoolSize, IsClusterMode: cfg.IsClusterMode, }) // Test connection if err := db.Ping(context.Background()).Err(); err != nil { panic(err) } // Empty collection if Clear is true if cfg.Reset { if err := db.FlushDB(context.Background()).Err(); err != nil { panic(err) } } // Create new store return &Storage{ db: db, } } // GetWithContext retrieves the value associated with the given key using the provided context. func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } val, err := s.db.Get(ctx, key).Bytes() if err == redis.Nil { return nil, nil } return val, err } // Get value by key func (s *Storage) Get(key string) ([]byte, error) { return s.GetWithContext(context.Background(), key) } // SetWithContext key with value with context func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 || len(val) <= 0 { return nil } return s.db.Set(ctx, key, val, exp).Err() } // Set key with value func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return s.SetWithContext(context.Background(), key, val, exp) } // DeleteWithContext key by key with context func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } return s.db.Del(ctx, key).Err() } // Delete key by key func (s *Storage) Delete(key string) error { return s.DeleteWithContext(context.Background(), key) } // ResetWithContext all keys with context func (s *Storage) ResetWithContext(ctx context.Context) error { return s.db.FlushDB(ctx).Err() } // Reset all keys func (s *Storage) Reset() error { return s.ResetWithContext(context.Background()) } // Close the database func (s *Storage) Close() error { return s.db.Close() } // Return database client func (s *Storage) Conn() redis.UniversalClient { return s.db } // Return all the keys func (s *Storage) Keys() ([][]byte, error) { var keys [][]byte var cursor uint64 var err error for { var batch []string if batch, cursor, err = s.db.Scan(context.Background(), cursor, "*", 10).Result(); err != nil { return nil, err } for _, key := range batch { keys = append(keys, []byte(key)) } if cursor == 0 { break } } if len(keys) == 0 { return nil, nil } return keys, nil }