diff --git a/store/use.go b/store/use.go index 42d75a8..8f7e564 100644 --- a/store/use.go +++ b/store/use.go @@ -1,5 +1,9 @@ package store +import "gorm.io/gorm" + +type callback func(db *gorm.DB) + func DB(configKey string) *MySQLStore { return &MySQLStore{configKey: configKey} } diff --git a/store/use_mysql.go b/store/use_mysql.go index 6e7b4c8..80ef887 100644 --- a/store/use_mysql.go +++ b/store/use_mysql.go @@ -14,6 +14,9 @@ import ( type MySQLStore struct { configKey string + + callback callback + hasCallback bool } func (m *MySQLStore) Config() *config.MySQL { @@ -26,6 +29,11 @@ func (m *MySQLStore) Config() *config.MySQL { return r } +func (m *MySQLStore) Callback(fn callback) { + m.callback = fn + m.hasCallback = true +} + func (m *MySQLStore) Use() *gorm.DB { r := m.Config() if r == nil { @@ -49,6 +57,10 @@ func (m *MySQLStore) Use() *gorm.DB { return nil } + if m.hasCallback { + m.callback(db) + } + sqlDB, err := db.DB() if err != nil { logger.SugarLog.Error("Error pinging database", zap.String("error", err.Error())) diff --git a/store/use_sqlite.go b/store/use_sqlite.go index 3aa5190..7851721 100644 --- a/store/use_sqlite.go +++ b/store/use_sqlite.go @@ -14,6 +14,9 @@ import ( type SQLiteStore struct { configKey string + + callback callback + hasCallback bool } func (m *SQLiteStore) Config() *config.Sqlite { @@ -26,6 +29,11 @@ func (m *SQLiteStore) Config() *config.Sqlite { return r } +func (m *SQLiteStore) Callback(fn callback) { + m.callback = fn + m.hasCallback = true +} + func (m *SQLiteStore) Use() *gorm.DB { r := m.Config() if r == nil { @@ -45,6 +53,10 @@ func (m *SQLiteStore) Use() *gorm.DB { return nil } + if m.hasCallback { + m.callback(db) + } + sqlDB, err := db.DB() if err != nil { logger.SugarLog.Error("Ping SQLite error", diff --git a/store/use_sqlserver.go b/store/use_sqlserver.go index b82fcf5..8c3e121 100644 --- a/store/use_sqlserver.go +++ b/store/use_sqlserver.go @@ -13,6 +13,9 @@ import ( type SqlServerStore struct { configKey string + + callback callback + hasCallback bool } func (m *SqlServerStore) Config() *config.SqlServer { @@ -25,6 +28,11 @@ func (m *SqlServerStore) Config() *config.SqlServer { return r } +func (m *SqlServerStore) Callback(fn callback) { + m.callback = fn + m.hasCallback = true +} + func (m *SqlServerStore) Use() *gorm.DB { r := m.Config() if r == nil { @@ -44,6 +52,10 @@ func (m *SqlServerStore) Use() *gorm.DB { return nil } + if m.hasCallback { + m.callback(db) + } + sqlDB, err := db.DB() if err != nil { logger.SugarLog.Errorf("%s (ping)", err.Error())