diff --git a/mysql/mysql_test.go b/mysql/mysql_test.go index 040a9440..2f70030a 100644 --- a/mysql/mysql_test.go +++ b/mysql/mysql_test.go @@ -132,15 +132,9 @@ func Test_MYSQL_Non_UTF8(t *testing.T) { require.Equal(t, val, result) } -func Test_MYSQL_Conn(t *testing.T) { - testStore := newTestStore(t) - defer testStore.Close() - - require.True(t, testStore.Conn() != nil) -} - func TestMySQLStorageTCK(t *testing.T) { - s, err := tck.New(context.Background(), t, &MySQLStorageTCK{}, tck.PerTest) + // The TCK needs the concrete type of the storage and the driver type returned by the Conn method. + s, err := tck.New[*Storage, *sql.DB](context.Background(), t, &MySQLStorageTCK{}, tck.PerTest) require.NoError(t, err) suite.Run(t, &s) diff --git a/storage.go b/storage.go index ebf58023..994f01f7 100644 --- a/storage.go +++ b/storage.go @@ -5,6 +5,13 @@ import ( "time" ) +type StorageWithConn[T any] interface { + // Conn returns a connection to the storage. + // Implementations should return a connection to the storage, + // using the proper driver for the storage. + Conn() T +} + // Storage interface for communicating with different database/key-value // providers. Visit https://github.com/gofiber/storage for more info. type Storage interface { diff --git a/testhelpers/tck/suite.go b/testhelpers/tck/suite.go index 02bd46ce..5bde894f 100644 --- a/testhelpers/tck/suite.go +++ b/testhelpers/tck/suite.go @@ -20,28 +20,28 @@ const ( PerSuite ) -type TCKSuite[T storage.Storage] interface { +type TCKSuite[T storage.Storage, D any] interface { NewStoreWithContainer() func(ctx context.Context, tb testing.TB) (T, testcontainers.Container, error) } // New creates a new [StorageTestSuite] with the given [TCKSuite]. -func New[T storage.Storage](ctx context.Context, t *testing.T, tckSuite TCKSuite[T], creationHook CreationHook) (StorageTestSuite[T], error) { +func New[T storage.Storage, D any](ctx context.Context, t *testing.T, tckSuite TCKSuite[T, D], creationHook CreationHook) (StorageTestSuite[T, D], error) { if creationHook != PerSuite && creationHook != PerTest { - return StorageTestSuite[T]{}, fmt.Errorf("invalid creation hook: %d", creationHook) + return StorageTestSuite[T, D]{}, fmt.Errorf("invalid creation hook: %d", creationHook) } if tckSuite == nil { - return StorageTestSuite[T]{}, fmt.Errorf("test suite is nil") + return StorageTestSuite[T, D]{}, fmt.Errorf("test suite is nil") } - return StorageTestSuite[T]{ + return StorageTestSuite[T, D]{ ctx: ctx, creationHook: creationHook, createFn: tckSuite.NewStoreWithContainer(), }, nil } -type StorageTestSuite[T storage.Storage] struct { +type StorageTestSuite[T storage.Storage, D any] struct { suite.Suite stats *suite.SuiteInformation ctx context.Context @@ -53,7 +53,7 @@ type StorageTestSuite[T storage.Storage] struct { ctr testcontainers.Container } -func (s *StorageTestSuite[T]) cleanup() error { +func (s *StorageTestSuite[T, D]) cleanup() error { t := s.T() t.Log("🧹 Cleaning up store and container") @@ -79,11 +79,11 @@ func (s *StorageTestSuite[T]) cleanup() error { // Hooks // ---------------------------------------------------------------------------- -func (s *StorageTestSuite[T]) HandleStats(_ string, stats *suite.SuiteInformation) { +func (s *StorageTestSuite[T, D]) HandleStats(_ string, stats *suite.SuiteInformation) { s.stats = stats } -func (s *StorageTestSuite[T]) SetupSuite() { +func (s *StorageTestSuite[T, D]) SetupSuite() { if s.creationHook == PerSuite { t := s.T() @@ -99,13 +99,13 @@ func (s *StorageTestSuite[T]) SetupSuite() { } } -func (s *StorageTestSuite[T]) TearDownSuite() { +func (s *StorageTestSuite[T, D]) TearDownSuite() { if s.creationHook == PerSuite { s.Require().NoError(s.cleanup()) } } -func (s *StorageTestSuite[T]) SetupTest() { +func (s *StorageTestSuite[T, D]) SetupTest() { if s.creationHook == PerTest { t := s.T() @@ -121,7 +121,7 @@ func (s *StorageTestSuite[T]) SetupTest() { } } -func (s *StorageTestSuite[T]) TearDownTest() { +func (s *StorageTestSuite[T, D]) TearDownTest() { if s.creationHook == PerTest { s.Require().NoError(s.cleanup()) } @@ -131,11 +131,22 @@ func (s *StorageTestSuite[T]) TearDownTest() { // Tests // ---------------------------------------------------------------------------- -func (s *StorageTestSuite[T]) TestSet() { +func (s *StorageTestSuite[T, D]) TestConn() { + storeWithConn, ok := s.store.(storage.StorageWithConn[D]) + if !ok { + s.T().Skip("Storage does not implement StorageWithConn") + return + } + + conn := storeWithConn.Conn() + s.Require().NotNil(conn, "Conn should not be nil") +} + +func (s *StorageTestSuite[T, D]) TestSet() { s.setValue("test_key", []byte("test_value")) } -func (s *StorageTestSuite[T]) TestSetWithContext() { +func (s *StorageTestSuite[T, D]) TestSetWithContext() { ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -143,19 +154,19 @@ func (s *StorageTestSuite[T]) TestSetWithContext() { s.Require().ErrorIs(err, context.Canceled) } -func (s *StorageTestSuite[T]) TestSetAndOverride() { +func (s *StorageTestSuite[T, D]) TestSetAndOverride() { s.setValue("test_key", []byte("test_value")) s.setValue("test_key", []byte("test_value_2")) s.requireKeyHasValue("test_key", []byte("test_value_2")) } -func (s *StorageTestSuite[T]) TestSetAndGet() { +func (s *StorageTestSuite[T, D]) TestSetAndGet() { s.setValue("test_key", []byte("test_value")) s.requireKeyHasValue("test_key", []byte("test_value")) } -func (s *StorageTestSuite[T]) TestGetWithContext() { +func (s *StorageTestSuite[T, D]) TestGetWithContext() { s.setValue("test_key", []byte("test_value")) ctx, cancel := context.WithCancel(context.Background()) @@ -166,13 +177,13 @@ func (s *StorageTestSuite[T]) TestGetWithContext() { s.Require().Zero(len(result)) } -func (s *StorageTestSuite[T]) TestGetMissing() { +func (s *StorageTestSuite[T, D]) TestGetMissing() { val, err := s.store.Get("non-existent-key") s.Require().NoError(err) s.Require().Zero(len(val)) } -func (s *StorageTestSuite[T]) TestGetExpired() { +func (s *StorageTestSuite[T, D]) TestGetExpired() { s.setValueWithTTL("temp_key", []byte("temp_value"), 500*time.Millisecond) s.Eventually(func() bool { @@ -181,7 +192,7 @@ func (s *StorageTestSuite[T]) TestGetExpired() { }, 2*time.Second, 100*time.Millisecond, "Key should expire") } -func (s *StorageTestSuite[T]) TestDelete() { +func (s *StorageTestSuite[T, D]) TestDelete() { s.setValue("delete_me", []byte("delete_value")) err := s.store.Delete("delete_me") @@ -190,7 +201,7 @@ func (s *StorageTestSuite[T]) TestDelete() { s.requireKeyNotExists("delete_me") } -func (s *StorageTestSuite[T]) TestDeleteWithContext() { +func (s *StorageTestSuite[T, D]) TestDeleteWithContext() { s.setValue("delete_me", []byte("delete_value")) ctx, cancel := context.WithCancel(context.Background()) @@ -204,7 +215,7 @@ func (s *StorageTestSuite[T]) TestDeleteWithContext() { s.Require().Equal([]byte("delete_value"), result) } -func (s *StorageTestSuite[T]) TestReset() { +func (s *StorageTestSuite[T, D]) TestReset() { s.setValue("key1", []byte("value1")) s.setValue("key2", []byte("value2")) @@ -218,7 +229,7 @@ func (s *StorageTestSuite[T]) TestReset() { s.requireKeyNotExists("key2") } -func (s *StorageTestSuite[T]) TestResetWithContext() { +func (s *StorageTestSuite[T, D]) TestResetWithContext() { s.setValue("key1", []byte("value1")) s.setValue("key2", []byte("value2")) @@ -235,7 +246,7 @@ func (s *StorageTestSuite[T]) TestResetWithContext() { s.requireKeyHasValue("key2", []byte("value2")) } -func (s *StorageTestSuite[T]) TestClose() { +func (s *StorageTestSuite[T, D]) TestClose() { err := s.store.Close() s.Require().NoError(err) @@ -249,27 +260,27 @@ func (s *StorageTestSuite[T]) TestClose() { // Helpers // ---------------------------------------------------------------------------- -func (s *StorageTestSuite[T]) setValue(key string, value []byte) { +func (s *StorageTestSuite[T, D]) setValue(key string, value []byte) { s.setValueWithTTL(key, value, 0) } -func (s *StorageTestSuite[T]) setValueWithTTL(key string, value []byte, ttl time.Duration) { +func (s *StorageTestSuite[T, D]) setValueWithTTL(key string, value []byte, ttl time.Duration) { err := s.store.Set(key, value, ttl) s.Require().NoError(err) } -func (s *StorageTestSuite[T]) getValue(key string) []byte { +func (s *StorageTestSuite[T, D]) getValue(key string) []byte { val, err := s.store.Get(key) s.Require().NoError(err) return val } -func (s *StorageTestSuite[T]) requireKeyHasValue(key string, expectedValue []byte) { +func (s *StorageTestSuite[T, D]) requireKeyHasValue(key string, expectedValue []byte) { val := s.getValue(key) s.Require().Equal(expectedValue, val) } -func (s *StorageTestSuite[T]) requireKeyNotExists(key string) { +func (s *StorageTestSuite[T, D]) requireKeyNotExists(key string) { val := s.getValue(key) s.Require().Nil(val) }