diff --git a/logger/zapgorm.go b/logger/zapgorm.go new file mode 100644 index 0000000..bc663ef --- /dev/null +++ b/logger/zapgorm.go @@ -0,0 +1,78 @@ +package logger + +import ( + "context" + "errors" + "time" + + "go.uber.org/zap" + "gorm.io/gorm" + gormLogger "gorm.io/gorm/logger" +) + +// ZapGormLogger implements gorm.io/gorm/logger.Interface using zap.SugaredLogger. +type ZapGormLogger struct { + sugar *zap.SugaredLogger + cfg gormLogger.Config +} + +// NewZapGormLogger constructs a ZapGormLogger. +// If sugar is nil, it falls back to logger.SugarLog; ensure Init() was called. +func NewZapGormLogger(sugar *zap.SugaredLogger, cfg gormLogger.Config) gormLogger.Interface { + if sugar == nil { + sugar = SugarLog + } + if cfg.SlowThreshold == 0 { + cfg.SlowThreshold = 200 * time.Millisecond + } + return &ZapGormLogger{sugar: sugar, cfg: cfg} +} + +// LogMode sets the logging level. +func (l *ZapGormLogger) LogMode(level gormLogger.LogLevel) gormLogger.Interface { + newCfg := l.cfg + newCfg.LogLevel = level + return &ZapGormLogger{sugar: l.sugar, cfg: newCfg} +} + +func (l *ZapGormLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.cfg.LogLevel < gormLogger.Info { + return + } + l.sugar.Infof(msg, data...) +} + +func (l *ZapGormLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.cfg.LogLevel < gormLogger.Warn { + return + } + l.sugar.Warnf(msg, data...) +} + +func (l *ZapGormLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.cfg.LogLevel < gormLogger.Error { + return + } + l.sugar.Errorf(msg, data...) +} + +// Trace prints SQL logs. It honors SlowThreshold and IgnoreRecordNotFoundError. +func (l *ZapGormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rows int64), err error) { + if l.cfg.LogLevel == gormLogger.Silent { + return + } + elapsed := time.Since(begin) + sql, rows := fc() + + switch { + case err != nil && l.cfg.LogLevel >= gormLogger.Error: + if l.cfg.IgnoreRecordNotFoundError && errors.Is(err, gorm.ErrRecordNotFound) { + break + } + l.sugar.Errorf("gorm sql error: elapsed=%s rows=%d err=%v sql=%s", elapsed, rows, err, sql) + case l.cfg.SlowThreshold != 0 && elapsed > l.cfg.SlowThreshold && l.cfg.LogLevel >= gormLogger.Warn: + l.sugar.Warnf("gorm slow sql: elapsed=%s rows=%d sql=%s", elapsed, rows, sql) + case l.cfg.LogLevel >= gormLogger.Info: + l.sugar.Infof("gorm sql: elapsed=%s rows=%d sql=%s", elapsed, rows, sql) + } +} \ No newline at end of file diff --git a/store/use.go b/store/use.go index 0ed9711..e7511de 100644 --- a/store/use.go +++ b/store/use.go @@ -15,22 +15,30 @@ var ( sqlServerStores sync.Map ) -func DB(configKey string) *MySQLStore { +func DB(configKey string, options ...*gorm.Config) *MySQLStore { if store, ok := mysqlStores.Load(configKey); ok { return store.(*MySQLStore) } newStore := &MySQLStore{configKey: configKey} + if len(options) > 0 && options[0] != nil { + newStore.Options(options[0]) + } + mysqlStores.Store(configKey, newStore) return newStore } -func SQLite(configKey string) *SQLiteStore { +func SQLite(configKey string, options ...*gorm.Config) *SQLiteStore { if store, ok := sqliteStores.Load(configKey); ok { return store.(*SQLiteStore) } newStore := &SQLiteStore{configKey: configKey} + if len(options) > 0 && options[0] != nil { + newStore.Options(options[0]) + } + sqliteStores.Store(configKey, newStore) return newStore } @@ -45,12 +53,16 @@ func Redis(configKey string) *RedisStore { return newStore } -func SqlServer(configKey string) *SqlServerStore { +func SqlServer(configKey string, options ...*gorm.Config) *SqlServerStore { if store, ok := sqlServerStores.Load(configKey); ok { return store.(*SqlServerStore) } newStore := &SqlServerStore{configKey: configKey} + if len(options) > 0 && options[0] != nil { + newStore.Options(options[0]) + } + sqlServerStores.Store(configKey, newStore) return newStore } diff --git a/store/use_mysql.go b/store/use_mysql.go index b4657e9..be5eceb 100644 --- a/store/use_mysql.go +++ b/store/use_mysql.go @@ -1,15 +1,15 @@ package store import ( - "github.com/spf13/viper" - "go.uber.org/zap" - "gorm.io/driver/mysql" - "gorm.io/gorm" - gormLogger "gorm.io/gorm/logger" - "gorm.io/gorm/schema" + "github.com/spf13/viper" + "go.uber.org/zap" + "gorm.io/driver/mysql" + "gorm.io/gorm" + gormLogger "gorm.io/gorm/logger" + "gorm.io/gorm/schema" - "github.com/wonli/aqi/internal/config" - "github.com/wonli/aqi/logger" + "github.com/wonli/aqi/internal/config" + "github.com/wonli/aqi/logger" ) type MySQLStore struct { @@ -45,9 +45,9 @@ func (m *MySQLStore) Callback(fn callback) { } func (m *MySQLStore) Use() *gorm.DB { - if m.gormDB != nil { - return m.gormDB - } + if m.gormDB != nil { + return m.gormDB + } r := m.Config() if r == nil { @@ -58,12 +58,12 @@ func (m *MySQLStore) Use() *gorm.DB { return nil } - conf := &gorm.Config{ - Logger: gormLogger.Default.LogMode(gormLogger.LogLevel(r.LogLevel)), - NamingStrategy: schema.NamingStrategy{ - TablePrefix: r.Prefix, - }, - } + conf := &gorm.Config{ + Logger: logger.NewZapGormLogger(logger.SugarLog, gormLogger.Config{LogLevel: gormLogger.LogLevel(r.LogLevel)}), + NamingStrategy: schema.NamingStrategy{ + TablePrefix: r.Prefix, + }, + } if m.options != nil { if m.options.Logger == nil { @@ -77,11 +77,11 @@ func (m *MySQLStore) Use() *gorm.DB { m.options = conf } - db, err := gorm.Open(mysql.Open(r.GetDsn()), conf) - if err != nil { - logger.SugarLog.Error("Failed to connect to MySQL database", zap.String("error", err.Error())) - return nil - } + db, err := gorm.Open(mysql.Open(r.GetDsn()), m.options) + if err != nil { + logger.SugarLog.Error("Failed to connect to MySQL database", zap.String("error", err.Error())) + return nil + } if m.hasCallback { m.callback(db)