feat: refactor user and role services to remove unnecessary database dependency, optimize service instantiation, and improve code clarity

This commit is contained in:
limitcool
2025-04-09 20:40:59 +08:00
parent 569f081fb4
commit 6a39976a00
5 changed files with 27 additions and 34 deletions

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"makefile.configureOnOpen": false
}

View File

@@ -26,9 +26,7 @@ func (uc *UserController) UserLogin(ctx *gin.Context) {
// 获取客户端IP地址
clientIP := ctx.ClientIP()
db := services.Instance().GetDB()
userService := services.NewNormalUserService(db)
userService := services.NewNormalUserService()
tokenResponse, err := userService.Login(req.Username, req.Password, clientIP)
if err != nil {
if errorx.IsErrCode(err) {
@@ -65,8 +63,7 @@ func (uc *UserController) UserRegister(c *gin.Context) {
RegisterIP: clientIP,
}
db := services.Instance().GetDB()
userService := services.NewNormalUserService(db)
userService := services.NewNormalUserService()
user, err := userService.Register(registerReq)
if err != nil {
if errorx.IsErrCode(err) {
@@ -98,8 +95,7 @@ func (uc *UserController) UserChangePassword(c *gin.Context) {
return
}
db := services.Instance().GetDB()
userService := services.NewNormalUserService(db)
userService := services.NewNormalUserService()
err := userService.ChangePassword(userID.(uint), req.OldPassword, req.NewPassword)
if err != nil {
if errorx.IsErrCode(err) {
@@ -118,8 +114,8 @@ func (uc *UserController) UserInfo(c *gin.Context) {
// 获取用户ID
userID, _ := c.Get("user_id")
db := services.Instance().GetDB()
userService := services.NewNormalUserService(db)
userService := services.NewNormalUserService()
user, err := userService.GetUserByID(userID.(uint))
if err != nil {
if errorx.IsErrCode(err) {

View File

@@ -10,7 +10,6 @@ import (
// RoleService 角色服务
type RoleService struct {
db *gorm.DB
casbinService *CasbinService
}
@@ -20,32 +19,30 @@ func NewRoleService() *RoleService {
if serviceInstance != nil {
// 使用ServiceManager获取依赖服务
return &RoleService{
db: db,
casbinService: serviceInstance.GetCasbinService(),
}
}
// 兼容旧代码如果ServiceManager未初始化则直接创建依赖服务
return &RoleService{
db: db,
casbinService: NewCasbinService(),
}
}
// CreateRole 创建角色
func (s *RoleService) CreateRole(role *model.Role) error {
return s.db.Create(role).Error
return db.Create(role).Error
}
// UpdateRole 更新角色
func (s *RoleService) UpdateRole(role *model.Role) error {
return s.db.Model(&model.Role{}).Where("id = ?", role.ID).Updates(role).Error
return db.Model(&model.Role{}).Where("id = ?", role.ID).Updates(role).Error
}
// DeleteRole 删除角色
func (s *RoleService) DeleteRole(id uint) error {
// 开启事务
return s.db.Transaction(func(tx *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
// 检查角色是否已分配给用户
var count int64
if err := tx.Model(&model.UserRole{}).Where("role_id = ?", id).Count(&count).Error; err != nil {
@@ -80,21 +77,21 @@ func (s *RoleService) DeleteRole(id uint) error {
// GetRoleByID 根据ID获取角色
func (s *RoleService) GetRoleByID(id uint) (*model.Role, error) {
var role model.Role
err := s.db.Where("id = ?", id).First(&role).Error
err := db.Where("id = ?", id).First(&role).Error
return &role, err
}
// GetRoles 获取角色列表
func (s *RoleService) GetRoles() ([]model.Role, error) {
var roles []model.Role
err := s.db.Order("sort").Find(&roles).Error
err := db.Order("sort").Find(&roles).Error
return roles, err
}
// AssignRolesToUser 为用户分配角色
func (s *RoleService) AssignRolesToUser(userID uint, roleIDs []uint) error {
// 开启事务
return s.db.Transaction(func(tx *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
// 删除原有的用户角色关联
if err := tx.Where("user_id = ?", userID).Delete(&model.UserRole{}).Error; err != nil {
return err
@@ -155,7 +152,7 @@ func (s *RoleService) AssignRolesToUser(userID uint, roleIDs []uint) error {
// GetUserRoleIDs 获取用户角色ID列表
func (s *RoleService) GetUserRoleIDs(userID uint) ([]uint, error) {
var userRoles []model.UserRole
err := s.db.Where("user_id = ?", userID).Find(&userRoles).Error
err := db.Where("user_id = ?", userID).Find(&userRoles).Error
if err != nil {
return nil, err
}
@@ -171,7 +168,7 @@ func (s *RoleService) GetUserRoleIDs(userID uint) ([]uint, error) {
// GetRoleMenuIDs 获取角色菜单ID列表
func (s *RoleService) GetRoleMenuIDs(roleID uint) ([]uint, error) {
var roleMenus []model.RoleMenu
err := s.db.Where("role_id = ?", roleID).Find(&roleMenus).Error
err := db.Where("role_id = ?", roleID).Find(&roleMenus).Error
if err != nil {
return nil, err
}

View File

@@ -100,7 +100,7 @@ func (sm *ServiceManager) GetOperationLogService() *OperationLogService {
// GetNormalUserService 获取NormalUserService实例
func (sm *ServiceManager) GetNormalUserService() *NormalUserService {
if sm.normalUserService == nil {
sm.normalUserService = NewNormalUserService(sm.GetDB())
sm.normalUserService = NewNormalUserService()
}
return sm.normalUserService
}

View File

@@ -18,20 +18,17 @@ import (
// NormalUserService 普通用户服务
type NormalUserService struct {
db *gorm.DB
}
// NewNormalUserService 创建普通用户服务
func NewNormalUserService(db *gorm.DB) *NormalUserService {
return &NormalUserService{
db: db,
}
func NewNormalUserService() *NormalUserService {
return &NormalUserService{}
}
// GetUserByID 根据ID获取用户
func (s *NormalUserService) GetUserByID(id uint) (*model.User, error) {
var user model.User
err := s.db.First(&user, id).Error
err := db.First(&user, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.NewErrCodeMsg(errorx.UserNotFound, "用户不存在")
}
@@ -44,7 +41,7 @@ func (s *NormalUserService) GetUserByID(id uint) (*model.User, error) {
// GetUserByUsername 根据用户名获取用户
func (s *NormalUserService) GetUserByUsername(username string) (*model.User, error) {
var user model.User
err := s.db.Where("username = ?", username).First(&user).Error
err := db.Where("username = ?", username).First(&user).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.NewErrCodeMsg(errorx.UserNotFound, "用户不存在")
}
@@ -76,7 +73,7 @@ type RegisterRequest struct {
func (s *NormalUserService) Register(req RegisterRequest) (*model.User, error) {
// 检查用户名是否已存在
var count int64
if err := s.db.Model(&model.User{}).Where("username = ?", req.Username).Count(&count).Error; err != nil {
if err := db.Model(&model.User{}).Where("username = ?", req.Username).Count(&count).Error; err != nil {
return nil, err
}
if count > 0 {
@@ -103,7 +100,7 @@ func (s *NormalUserService) Register(req RegisterRequest) (*model.User, error) {
RegisterIP: req.RegisterIP,
}
if err := s.db.Create(user).Error; err != nil {
if err := db.Create(user).Error; err != nil {
return nil, err
}
@@ -137,7 +134,7 @@ func (s *NormalUserService) Login(username, password string, ip string) (*LoginR
}
// 更新最后登录时间和IP
s.db.Model(user).Updates(map[string]interface{}{
db.Model(user).Updates(map[string]interface{}{
"last_login": time.Now(),
"last_ip": ip,
})
@@ -190,7 +187,7 @@ func (s *NormalUserService) UpdateUser(id uint, data map[string]interface{}) err
delete(data, "deleted_at")
// 更新用户信息
return s.db.Model(&model.User{}).Where("id = ?", id).Updates(data).Error
return db.Model(&model.User{}).Where("id = ?", id).Updates(data).Error
}
// ChangePassword 修改密码
@@ -213,11 +210,11 @@ func (s *NormalUserService) ChangePassword(id uint, oldPassword, newPassword str
}
// 更新密码
return s.db.Model(&model.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
return db.Model(&model.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
}
func GetUserInfo(ctx *gin.Context) {
userId, exists := ctx.Get("userID")
userId, exists := ctx.Get("userID")
if !exists {
response.Fail(ctx, errorx.UserNoLogin, "")
return