feat: refactor user and role services to remove unnecessary database dependency, optimize service instantiation, and improve code clarity

This commit is contained in:
limitcool
2025-04-09 20:40:59 +08:00
parent 569f081fb4
commit 6a39976a00
5 changed files with 27 additions and 34 deletions

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"makefile.configureOnOpen": false
}

View File

@@ -26,9 +26,7 @@ func (uc *UserController) UserLogin(ctx *gin.Context) {
// 获取客户端IP地址 // 获取客户端IP地址
clientIP := ctx.ClientIP() clientIP := ctx.ClientIP()
userService := services.NewNormalUserService()
db := services.Instance().GetDB()
userService := services.NewNormalUserService(db)
tokenResponse, err := userService.Login(req.Username, req.Password, clientIP) tokenResponse, err := userService.Login(req.Username, req.Password, clientIP)
if err != nil { if err != nil {
if errorx.IsErrCode(err) { if errorx.IsErrCode(err) {
@@ -65,8 +63,7 @@ func (uc *UserController) UserRegister(c *gin.Context) {
RegisterIP: clientIP, RegisterIP: clientIP,
} }
db := services.Instance().GetDB() userService := services.NewNormalUserService()
userService := services.NewNormalUserService(db)
user, err := userService.Register(registerReq) user, err := userService.Register(registerReq)
if err != nil { if err != nil {
if errorx.IsErrCode(err) { if errorx.IsErrCode(err) {
@@ -98,8 +95,7 @@ func (uc *UserController) UserChangePassword(c *gin.Context) {
return return
} }
db := services.Instance().GetDB() userService := services.NewNormalUserService()
userService := services.NewNormalUserService(db)
err := userService.ChangePassword(userID.(uint), req.OldPassword, req.NewPassword) err := userService.ChangePassword(userID.(uint), req.OldPassword, req.NewPassword)
if err != nil { if err != nil {
if errorx.IsErrCode(err) { if errorx.IsErrCode(err) {
@@ -118,8 +114,8 @@ func (uc *UserController) UserInfo(c *gin.Context) {
// 获取用户ID // 获取用户ID
userID, _ := c.Get("user_id") userID, _ := c.Get("user_id")
db := services.Instance().GetDB()
userService := services.NewNormalUserService(db) userService := services.NewNormalUserService()
user, err := userService.GetUserByID(userID.(uint)) user, err := userService.GetUserByID(userID.(uint))
if err != nil { if err != nil {
if errorx.IsErrCode(err) { if errorx.IsErrCode(err) {

View File

@@ -10,7 +10,6 @@ import (
// RoleService 角色服务 // RoleService 角色服务
type RoleService struct { type RoleService struct {
db *gorm.DB
casbinService *CasbinService casbinService *CasbinService
} }
@@ -20,32 +19,30 @@ func NewRoleService() *RoleService {
if serviceInstance != nil { if serviceInstance != nil {
// 使用ServiceManager获取依赖服务 // 使用ServiceManager获取依赖服务
return &RoleService{ return &RoleService{
db: db,
casbinService: serviceInstance.GetCasbinService(), casbinService: serviceInstance.GetCasbinService(),
} }
} }
// 兼容旧代码如果ServiceManager未初始化则直接创建依赖服务 // 兼容旧代码如果ServiceManager未初始化则直接创建依赖服务
return &RoleService{ return &RoleService{
db: db,
casbinService: NewCasbinService(), casbinService: NewCasbinService(),
} }
} }
// CreateRole 创建角色 // CreateRole 创建角色
func (s *RoleService) CreateRole(role *model.Role) error { func (s *RoleService) CreateRole(role *model.Role) error {
return s.db.Create(role).Error return db.Create(role).Error
} }
// UpdateRole 更新角色 // UpdateRole 更新角色
func (s *RoleService) UpdateRole(role *model.Role) error { func (s *RoleService) UpdateRole(role *model.Role) error {
return s.db.Model(&model.Role{}).Where("id = ?", role.ID).Updates(role).Error return db.Model(&model.Role{}).Where("id = ?", role.ID).Updates(role).Error
} }
// DeleteRole 删除角色 // DeleteRole 删除角色
func (s *RoleService) DeleteRole(id uint) error { func (s *RoleService) DeleteRole(id uint) error {
// 开启事务 // 开启事务
return s.db.Transaction(func(tx *gorm.DB) error { return db.Transaction(func(tx *gorm.DB) error {
// 检查角色是否已分配给用户 // 检查角色是否已分配给用户
var count int64 var count int64
if err := tx.Model(&model.UserRole{}).Where("role_id = ?", id).Count(&count).Error; err != nil { if err := tx.Model(&model.UserRole{}).Where("role_id = ?", id).Count(&count).Error; err != nil {
@@ -80,21 +77,21 @@ func (s *RoleService) DeleteRole(id uint) error {
// GetRoleByID 根据ID获取角色 // GetRoleByID 根据ID获取角色
func (s *RoleService) GetRoleByID(id uint) (*model.Role, error) { func (s *RoleService) GetRoleByID(id uint) (*model.Role, error) {
var role model.Role var role model.Role
err := s.db.Where("id = ?", id).First(&role).Error err := db.Where("id = ?", id).First(&role).Error
return &role, err return &role, err
} }
// GetRoles 获取角色列表 // GetRoles 获取角色列表
func (s *RoleService) GetRoles() ([]model.Role, error) { func (s *RoleService) GetRoles() ([]model.Role, error) {
var roles []model.Role var roles []model.Role
err := s.db.Order("sort").Find(&roles).Error err := db.Order("sort").Find(&roles).Error
return roles, err return roles, err
} }
// AssignRolesToUser 为用户分配角色 // AssignRolesToUser 为用户分配角色
func (s *RoleService) AssignRolesToUser(userID uint, roleIDs []uint) error { func (s *RoleService) AssignRolesToUser(userID uint, roleIDs []uint) error {
// 开启事务 // 开启事务
return s.db.Transaction(func(tx *gorm.DB) error { return db.Transaction(func(tx *gorm.DB) error {
// 删除原有的用户角色关联 // 删除原有的用户角色关联
if err := tx.Where("user_id = ?", userID).Delete(&model.UserRole{}).Error; err != nil { if err := tx.Where("user_id = ?", userID).Delete(&model.UserRole{}).Error; err != nil {
return err return err
@@ -155,7 +152,7 @@ func (s *RoleService) AssignRolesToUser(userID uint, roleIDs []uint) error {
// GetUserRoleIDs 获取用户角色ID列表 // GetUserRoleIDs 获取用户角色ID列表
func (s *RoleService) GetUserRoleIDs(userID uint) ([]uint, error) { func (s *RoleService) GetUserRoleIDs(userID uint) ([]uint, error) {
var userRoles []model.UserRole var userRoles []model.UserRole
err := s.db.Where("user_id = ?", userID).Find(&userRoles).Error err := db.Where("user_id = ?", userID).Find(&userRoles).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -171,7 +168,7 @@ func (s *RoleService) GetUserRoleIDs(userID uint) ([]uint, error) {
// GetRoleMenuIDs 获取角色菜单ID列表 // GetRoleMenuIDs 获取角色菜单ID列表
func (s *RoleService) GetRoleMenuIDs(roleID uint) ([]uint, error) { func (s *RoleService) GetRoleMenuIDs(roleID uint) ([]uint, error) {
var roleMenus []model.RoleMenu var roleMenus []model.RoleMenu
err := s.db.Where("role_id = ?", roleID).Find(&roleMenus).Error err := db.Where("role_id = ?", roleID).Find(&roleMenus).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -100,7 +100,7 @@ func (sm *ServiceManager) GetOperationLogService() *OperationLogService {
// GetNormalUserService 获取NormalUserService实例 // GetNormalUserService 获取NormalUserService实例
func (sm *ServiceManager) GetNormalUserService() *NormalUserService { func (sm *ServiceManager) GetNormalUserService() *NormalUserService {
if sm.normalUserService == nil { if sm.normalUserService == nil {
sm.normalUserService = NewNormalUserService(sm.GetDB()) sm.normalUserService = NewNormalUserService()
} }
return sm.normalUserService return sm.normalUserService
} }

View File

@@ -18,20 +18,17 @@ import (
// NormalUserService 普通用户服务 // NormalUserService 普通用户服务
type NormalUserService struct { type NormalUserService struct {
db *gorm.DB
} }
// NewNormalUserService 创建普通用户服务 // NewNormalUserService 创建普通用户服务
func NewNormalUserService(db *gorm.DB) *NormalUserService { func NewNormalUserService() *NormalUserService {
return &NormalUserService{ return &NormalUserService{}
db: db,
}
} }
// GetUserByID 根据ID获取用户 // GetUserByID 根据ID获取用户
func (s *NormalUserService) GetUserByID(id uint) (*model.User, error) { func (s *NormalUserService) GetUserByID(id uint) (*model.User, error) {
var user model.User var user model.User
err := s.db.First(&user, id).Error err := db.First(&user, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.NewErrCodeMsg(errorx.UserNotFound, "用户不存在") return nil, errorx.NewErrCodeMsg(errorx.UserNotFound, "用户不存在")
} }
@@ -44,7 +41,7 @@ func (s *NormalUserService) GetUserByID(id uint) (*model.User, error) {
// GetUserByUsername 根据用户名获取用户 // GetUserByUsername 根据用户名获取用户
func (s *NormalUserService) GetUserByUsername(username string) (*model.User, error) { func (s *NormalUserService) GetUserByUsername(username string) (*model.User, error) {
var user model.User var user model.User
err := s.db.Where("username = ?", username).First(&user).Error err := db.Where("username = ?", username).First(&user).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.NewErrCodeMsg(errorx.UserNotFound, "用户不存在") return nil, errorx.NewErrCodeMsg(errorx.UserNotFound, "用户不存在")
} }
@@ -76,7 +73,7 @@ type RegisterRequest struct {
func (s *NormalUserService) Register(req RegisterRequest) (*model.User, error) { func (s *NormalUserService) Register(req RegisterRequest) (*model.User, error) {
// 检查用户名是否已存在 // 检查用户名是否已存在
var count int64 var count int64
if err := s.db.Model(&model.User{}).Where("username = ?", req.Username).Count(&count).Error; err != nil { if err := db.Model(&model.User{}).Where("username = ?", req.Username).Count(&count).Error; err != nil {
return nil, err return nil, err
} }
if count > 0 { if count > 0 {
@@ -103,7 +100,7 @@ func (s *NormalUserService) Register(req RegisterRequest) (*model.User, error) {
RegisterIP: req.RegisterIP, RegisterIP: req.RegisterIP,
} }
if err := s.db.Create(user).Error; err != nil { if err := db.Create(user).Error; err != nil {
return nil, err return nil, err
} }
@@ -137,7 +134,7 @@ func (s *NormalUserService) Login(username, password string, ip string) (*LoginR
} }
// 更新最后登录时间和IP // 更新最后登录时间和IP
s.db.Model(user).Updates(map[string]interface{}{ db.Model(user).Updates(map[string]interface{}{
"last_login": time.Now(), "last_login": time.Now(),
"last_ip": ip, "last_ip": ip,
}) })
@@ -190,7 +187,7 @@ func (s *NormalUserService) UpdateUser(id uint, data map[string]interface{}) err
delete(data, "deleted_at") delete(data, "deleted_at")
// 更新用户信息 // 更新用户信息
return s.db.Model(&model.User{}).Where("id = ?", id).Updates(data).Error return db.Model(&model.User{}).Where("id = ?", id).Updates(data).Error
} }
// ChangePassword 修改密码 // ChangePassword 修改密码
@@ -213,7 +210,7 @@ func (s *NormalUserService) ChangePassword(id uint, oldPassword, newPassword str
} }
// 更新密码 // 更新密码
return s.db.Model(&model.User{}).Where("id = ?", id).Update("password", hashedPassword).Error return db.Model(&model.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
} }
func GetUserInfo(ctx *gin.Context) { func GetUserInfo(ctx *gin.Context) {