diff --git a/app/http/controller/api.go b/app/http/controller/api.go index 48266b8..ce812a8 100644 --- a/app/http/controller/api.go +++ b/app/http/controller/api.go @@ -6,16 +6,15 @@ import ( ) type Api struct { - } // CurrentUser 获取当前用户 -func (a Api)currentUser(c *gin.Context) *model.User { +func (a Api) currentUser(c *gin.Context) *model.User { if userId, _ := c.Get("userId"); userId != nil { - user, err := model.GetUser(userId) + user, err := model.GetUser(c, userId) if err == nil { return &user } } return nil -} \ No newline at end of file +} diff --git a/app/http/controller/user.go b/app/http/controller/user.go index 6b58229..44a4bf2 100644 --- a/app/http/controller/user.go +++ b/app/http/controller/user.go @@ -28,7 +28,7 @@ func (a UserController) UserRegister(c *gin.Context) (res interface{}, err error return } - res = service.Register(param) + res = service.Register(c, param) return } diff --git a/app/http/service/user_service.go b/app/http/service/user_service.go index bcd01e4..5b82273 100644 --- a/app/http/service/user_service.go +++ b/app/http/service/user_service.go @@ -14,10 +14,10 @@ type UserService struct { } // Login 用户登录函数 -func (service *UserService) Login(c *gin.Context, loginRequest request.LoginRequest) response.Response { +func (service *UserService) Login(ctx *gin.Context, loginRequest request.LoginRequest) response.Response { var userModel model.User - if err := global.DB.Where("user_name = ?", loginRequest.UserName).First(&userModel).Error; err != nil { + if err := global.DB(ctx).Where("user_name = ?", loginRequest.UserName).First(&userModel).Error; err != nil { return response.ParamErr("账号或密码错误", nil, nil) } @@ -34,7 +34,7 @@ func (service *UserService) Login(c *gin.Context, loginRequest request.LoginRequ } // valid 验证表单 -func (service *UserService) valid(registerRequest request.RegisterRequest) *response.Response { +func (service *UserService) valid(ctx *gin.Context, registerRequest request.RegisterRequest) *response.Response { if registerRequest.PasswordConfirm != registerRequest.Password { return &response.Response{ Code: 40001, @@ -43,7 +43,7 @@ func (service *UserService) valid(registerRequest request.RegisterRequest) *resp } count := int64(0) - global.DB.Model(&model.User{}).Where("nickname = ?", registerRequest.Nickname).Count(&count) + global.DB(ctx).Model(&model.User{}).Where("nickname = ?", registerRequest.Nickname).Count(&count) if count > 0 { return &response.Response{ Code: 40001, @@ -52,7 +52,7 @@ func (service *UserService) valid(registerRequest request.RegisterRequest) *resp } count = 0 - global.DB.Model(&model.User{}).Where("user_name = ?", registerRequest.UserName).Count(&count) + global.DB(ctx).Model(&model.User{}).Where("user_name = ?", registerRequest.UserName).Count(&count) if count > 0 { return &response.Response{ Code: 40001, @@ -64,7 +64,7 @@ func (service *UserService) valid(registerRequest request.RegisterRequest) *resp } // Register 用户注册 -func (service *UserService) Register(registerRequest request.RegisterRequest) response.Response { +func (service *UserService) Register(ctx *gin.Context, registerRequest request.RegisterRequest) response.Response { user := model.User{ Nickname: registerRequest.Nickname, UserName: registerRequest.UserName, @@ -72,7 +72,7 @@ func (service *UserService) Register(registerRequest request.RegisterRequest) re } // 表单验证 - if err := service.valid(registerRequest); err != nil { + if err := service.valid(ctx, registerRequest); err != nil { return *err } @@ -87,7 +87,7 @@ func (service *UserService) Register(registerRequest request.RegisterRequest) re } // 创建用户 - if err := global.DB.Create(&user).Error; err != nil { + if err := global.DB(ctx).Create(&user).Error; err != nil { return response.ParamErr("注册失败", nil, err) } diff --git a/app/model/user.go b/app/model/user.go index b22683d..d58817d 100644 --- a/app/model/user.go +++ b/app/model/user.go @@ -1,6 +1,7 @@ package model import ( + "context" "fastApi/core/global" "github.com/golang-jwt/jwt/v4" "golang.org/x/crypto/bcrypt" @@ -34,9 +35,9 @@ const ( ) // GetUser 用ID获取用户 -func GetUser(ID interface{}) (User, error) { +func GetUser(ctx context.Context, ID interface{}) (User, error) { var user User - result := global.DB.First(&user, ID) + result := global.DB(ctx).First(&user, ID) return user, result.Error } diff --git a/cmd/migrate.go b/cmd/migrate.go index 622fd17..80d4551 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -16,5 +16,5 @@ func init() { } func migrate(cmd *cobra.Command, args []string) { - _ = global.DB.AutoMigrate(&model.User{}) + _ = global.GDB.AutoMigrate(&model.User{}) } diff --git a/core/global/core.go b/core/global/core.go index ea8b051..f132541 100644 --- a/core/global/core.go +++ b/core/global/core.go @@ -1,6 +1,7 @@ package global import ( + "context" ut "github.com/go-playground/universal-translator" "github.com/go-redis/redis" "github.com/nsqio/go-nsq" @@ -8,11 +9,16 @@ import ( "gorm.io/gorm" ) +const DBKey = "DB" + var ( Trans ut.Translator // 定义一个全局翻译器T Log *zap.Logger - SLog *zap.SugaredLogger - DB *gorm.DB // DB 数据库链接单例 + GDB *gorm.DB // DB 数据库链接单例 Redis *redis.Client Producer *nsq.Producer ) + +func DB(ctx context.Context) *gorm.DB { + return ctx.Value(DBKey).(*gorm.DB) +} diff --git a/core/gorm.go b/core/gorm.go index fde628e..bfb8d4e 100644 --- a/core/gorm.go +++ b/core/gorm.go @@ -43,5 +43,5 @@ func Database() { sqlDB.SetMaxIdleConns(viper.GetInt("database.max_idle_conn")) //打开 sqlDB.SetMaxOpenConns(viper.GetInt("database.max_open_conn")) - global.DB = db + global.GDB = db } diff --git a/core/logger/zap.go b/core/logger/zap.go index 29e9a04..3c2ca09 100644 --- a/core/logger/zap.go +++ b/core/logger/zap.go @@ -1,7 +1,9 @@ package logger import ( + "context" "fastApi/core/global" + "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/spf13/viper" "go.uber.org/zap" @@ -10,6 +12,8 @@ import ( "os" ) +const loggerKey = "Log" +const loggerSugarKey = "LogSugar" const TraceId = "traceId" var logger *zap.Logger @@ -56,8 +60,31 @@ func CalcTraceId() (traceId string) { return uuid.New().String() } -func With(fields ...zap.Field) { - global.Log = logger.With(fields...) - global.SLog = global.Log.Sugar() - global.DB.Logger = NewGormLog(global.Log) +func With(c *gin.Context, fields ...zap.Field) { + log := logger.With(fields...) + slog := log.Sugar() + db := global.GDB + db.Logger = NewGormLog(log) + c.Set(loggerKey, log) + c.Set(loggerSugarKey, slog) + c.Set(global.DBKey, db) +} + +func WithC(c context.Context, fields ...zap.Field) context.Context { + log := logger.With(fields...) + slog := log.Sugar() + db := global.GDB + db.Logger = NewGormLog(log) + c = context.WithValue(c, loggerKey, log) + c = context.WithValue(c, loggerSugarKey, slog) + c = context.WithValue(c, global.DBKey, db) + return c +} + +func Log(c context.Context) *zap.Logger { + return c.Value(loggerKey).(*zap.Logger) +} + +func SLog(c context.Context) *zap.SugaredLogger { + return c.Value(loggerSugarKey).(*zap.SugaredLogger) } diff --git a/core/middleware/zap.go b/core/middleware/zap.go index a5bee90..4254395 100644 --- a/core/middleware/zap.go +++ b/core/middleware/zap.go @@ -128,6 +128,7 @@ func AddTraceId() gin.HandlerFunc { traceId := logger.CalcTraceId() ctx.Set(logger.TraceId, traceId) logger.With( + ctx, zap.String("traceId", traceId), ) ctx.Next() diff --git a/crontab/cron_init.go b/crontab/cron_init.go index e0d812d..95a4dff 100644 --- a/crontab/cron_init.go +++ b/crontab/cron_init.go @@ -50,8 +50,9 @@ func CronInit() { Cron.Start() } -func WithRequestId(name, traceId string) { - logger.With( +func WithRequestId(ctx context.Context, name, traceId string) context.Context { + return logger.WithC( + ctx, zap.String("traceId", traceId), zap.String("name", name), ) @@ -62,7 +63,7 @@ func BaseCronFuc(name string, cmd func(context.Context)) func() { traceId := uuid.New().String() ctx := context.WithValue(context.Background(), logger.TraceId, traceId) - WithRequestId(name, traceId) + ctx = WithRequestId(ctx, name, traceId) cmd(ctx) } } diff --git a/crontab/testJob.go b/crontab/testJob.go index fae6163..501be77 100644 --- a/crontab/testJob.go +++ b/crontab/testJob.go @@ -3,7 +3,7 @@ package crontab import ( "context" "fastApi/app/model" - "fastApi/core/global" + "fastApi/core/logger" ) func init() { @@ -22,6 +22,6 @@ func (j testJob) getName() string { } func (j testJob) Run(ctx context.Context) { - model.GetUser(1) - global.Log.Info("tick every 1 second run once") + model.GetUser(ctx, 1) + logger.Log(ctx).Info("tick every 1 second run once") } diff --git a/mq/mq_base.go b/mq/mq_base.go index b63b81f..32edd0d 100644 --- a/mq/mq_base.go +++ b/mq/mq_base.go @@ -14,6 +14,8 @@ import ( var MQList []InterfaceMQ +type HandleFunc func(context.Context, string) error + type InterfaceMQ interface { Producer(ctx context.Context, message []byte, delay ...time.Duration) error HandleMessage(msg *nsq.Message) error @@ -65,7 +67,7 @@ func (b *BaseMQ) Producer(ctx context.Context, message []byte, delay ...time.Dur return err } -func (b *BaseMQ) Handle(msg *nsq.Message, h func(string) error) error { +func (b *BaseMQ) Handle(msg *nsq.Message, h HandleFunc) error { startTime := time.Now() var data map[string]string @@ -79,14 +81,16 @@ func (b *BaseMQ) Handle(msg *nsq.Message, h func(string) error) error { ).Error("数据解析失败: " + err.Error()) } - logger.With( + ctx := context.WithValue(context.Background(), logger.TraceId, data["traceId"]) + logger.WithC( + ctx, zap.String("traceId", data["traceId"]), ) - err = h(data["message"]) + err = h(ctx, data["message"]) endTime := time.Now() latencyTime := endTime.Sub(startTime) - log := global.Log.With( + log := logger.Log(ctx).With( zap.String("url", b.GetTopic()), zap.String("params", data["message"]), zap.Uint16("attempts", msg.Attempts), diff --git a/mq/send_registered_email.go b/mq/send_registered_email.go index ea18392..37ac154 100644 --- a/mq/send_registered_email.go +++ b/mq/send_registered_email.go @@ -1,8 +1,9 @@ package mq import ( + "context" "fastApi/app/model" - "fastApi/core/global" + "fastApi/core/logger" "github.com/nsqio/go-nsq" ) @@ -11,9 +12,9 @@ type SendRegisteredEmail struct { } func (c *SendRegisteredEmail) HandleMessage(msg *nsq.Message) error { - return c.Handle(msg, func(data string) error { - model.GetUser(1) - global.Log.Info("ok") + return c.Handle(msg, func(ctx context.Context, data string) error { + model.GetUser(ctx, 1) + logger.Log(ctx).Info("ok") return nil }) }