mirror of
https://github.com/limitcool/starter.git
synced 2025-09-27 04:36:18 +08:00
Refactor user handler and middleware for improved error handling and logging
- Consolidated user ID retrieval and permission checks into helper functions. - Updated UserHandler to utilize BaseHandler for common database and configuration access. - Enhanced logging for user-related operations, including login, registration, and password changes. - Removed redundant context handling in middleware and improved readability. - Introduced FileUtil for file URL generation and management, encapsulating file-related logic. - Refactored FileRepo and UserRepo to streamline database operations and error handling. - Deleted unused request_id middleware and integrated its functionality into request_logger. - Removed legacy test runner script to simplify testing process.
This commit is contained in:
6
Makefile
6
Makefile
@@ -86,7 +86,11 @@ version:
|
|||||||
.PHONY: build-dev
|
.PHONY: build-dev
|
||||||
build-dev:
|
build-dev:
|
||||||
@echo "Building $(APP_NAME) for development..."
|
@echo "Building $(APP_NAME) for development..."
|
||||||
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME)$(if $(filter windows,$(shell go env GOOS)),.exe,)
|
ifeq ($(OS),Windows_NT)
|
||||||
|
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME).exe
|
||||||
|
else
|
||||||
|
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME)
|
||||||
|
endif
|
||||||
|
|
||||||
# Docker 相关目标
|
# Docker 相关目标
|
||||||
.PHONY: docker-build
|
.PHONY: docker-build
|
||||||
|
@@ -42,44 +42,65 @@ type Handlers struct {
|
|||||||
Admin *handler.AdminHandler
|
Admin *handler.AdminHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitStep 初始化步骤
|
||||||
|
type InitStep struct {
|
||||||
|
Name string
|
||||||
|
Required bool
|
||||||
|
Init func() error
|
||||||
|
}
|
||||||
|
|
||||||
// New 创建新的应用实例
|
// New 创建新的应用实例
|
||||||
func New(config *configs.Config) (*App, error) {
|
func New(config *configs.Config) (*App, error) {
|
||||||
app := &App{
|
app := &App{config: config}
|
||||||
config: config,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 按顺序初始化各个组件
|
// 定义初始化步骤
|
||||||
if err := app.initDatabase(); err != nil {
|
steps := app.getInitSteps()
|
||||||
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := app.initRedis(); err != nil {
|
// 按顺序执行初始化
|
||||||
return nil, fmt.Errorf("failed to initialize redis: %w", err)
|
for _, step := range steps {
|
||||||
}
|
logger.Info("Initializing component", "component", step.Name)
|
||||||
|
|
||||||
if err := app.initStorage(); err != nil {
|
if err := step.Init(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize storage: %w", err)
|
logger.Error("Failed to initialize component",
|
||||||
}
|
"component", step.Name,
|
||||||
|
"required", step.Required,
|
||||||
|
"error", err)
|
||||||
|
|
||||||
if err := app.initHandlers(); err != nil {
|
if step.Required {
|
||||||
return nil, fmt.Errorf("failed to initialize handlers: %w", err)
|
return nil, fmt.Errorf("failed to initialize required component %s: %w", step.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := app.initRouter(); err != nil {
|
logger.Warn("Optional component initialization failed, continuing",
|
||||||
return nil, fmt.Errorf("failed to initialize router: %w", err)
|
"component", step.Name)
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if err := app.initServer(); err != nil {
|
logger.Info("Component initialized successfully", "component", step.Name)
|
||||||
return nil, fmt.Errorf("failed to initialize server: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := app.initPprof(); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to initialize pprof: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getInitSteps 获取初始化步骤列表
|
||||||
|
func (app *App) getInitSteps() []InitStep {
|
||||||
|
steps := []InitStep{
|
||||||
|
// 数据库和Redis根据配置启用,失败时不影响应用启动(内部有禁用检查)
|
||||||
|
{Name: "database", Required: false, Init: app.initDatabase},
|
||||||
|
{Name: "redis", Required: false, Init: app.initRedis},
|
||||||
|
|
||||||
|
// 存储服务是可选的,某些功能可能需要它
|
||||||
|
{Name: "storage", Required: false, Init: app.initStorage},
|
||||||
|
|
||||||
|
// 核心组件,必须成功初始化
|
||||||
|
{Name: "handlers", Required: true, Init: app.initHandlers},
|
||||||
|
{Name: "router", Required: true, Init: app.initRouter},
|
||||||
|
{Name: "server", Required: true, Init: app.initServer},
|
||||||
|
{Name: "pprof", Required: false, Init: app.initPprof},
|
||||||
|
}
|
||||||
|
|
||||||
|
return steps
|
||||||
|
}
|
||||||
|
|
||||||
// initDatabase 初始化数据库连接
|
// initDatabase 初始化数据库连接
|
||||||
func (a *App) initDatabase() error {
|
func (a *App) initDatabase() error {
|
||||||
if !a.config.Database.Enabled {
|
if !a.config.Database.Enabled {
|
||||||
|
@@ -10,18 +10,16 @@ import (
|
|||||||
|
|
||||||
// AdminHandler 管理员处理器
|
// AdminHandler 管理员处理器
|
||||||
type AdminHandler struct {
|
type AdminHandler struct {
|
||||||
db *gorm.DB
|
*BaseHandler
|
||||||
config *configs.Config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminHandler 创建管理员处理器
|
// NewAdminHandler 创建管理员处理器
|
||||||
func NewAdminHandler(db *gorm.DB, config *configs.Config) *AdminHandler {
|
func NewAdminHandler(db *gorm.DB, config *configs.Config) *AdminHandler {
|
||||||
handler := &AdminHandler{
|
handler := &AdminHandler{
|
||||||
db: db,
|
BaseHandler: NewBaseHandler(db, config),
|
||||||
config: config,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("AdminHandler initialized")
|
handler.LogInit("AdminHandler")
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,9 +33,9 @@ func (h *AdminHandler) GetSystemSettings(ctx *gin.Context) {
|
|||||||
|
|
||||||
// 返回系统设置
|
// 返回系统设置
|
||||||
settings := map[string]any{
|
settings := map[string]any{
|
||||||
"app_name": h.config.App.Name,
|
"app_name": h.Config.App.Name,
|
||||||
"app_version": "1.0.0",
|
"app_version": "1.0.0",
|
||||||
"app_mode": h.config.App.Mode,
|
"app_mode": h.Config.App.Mode,
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(ctx, settings)
|
response.Success(ctx, settings)
|
||||||
|
30
internal/handler/base_handler.go
Normal file
30
internal/handler/base_handler.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/limitcool/starter/configs"
|
||||||
|
"github.com/limitcool/starter/internal/pkg/logger"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BaseHandler 基础处理器,包含所有Handler的公共字段和方法
|
||||||
|
type BaseHandler struct {
|
||||||
|
DB *gorm.DB
|
||||||
|
Config *configs.Config
|
||||||
|
Helper *HandlerHelper
|
||||||
|
FileUtil *FileUtil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBaseHandler 创建基础处理器
|
||||||
|
func NewBaseHandler(db *gorm.DB, config *configs.Config) *BaseHandler {
|
||||||
|
return &BaseHandler{
|
||||||
|
DB: db,
|
||||||
|
Config: config,
|
||||||
|
Helper: NewHandlerHelper(),
|
||||||
|
FileUtil: NewFileUtil("/uploads"), // 默认基础URL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogInit 记录Handler初始化日志
|
||||||
|
func (h *BaseHandler) LogInit(handlerName string) {
|
||||||
|
logger.Info(handlerName + " initialized")
|
||||||
|
}
|
@@ -19,20 +19,18 @@ import (
|
|||||||
|
|
||||||
// FileHandler 文件处理器
|
// FileHandler 文件处理器
|
||||||
type FileHandler struct {
|
type FileHandler struct {
|
||||||
db *gorm.DB
|
*BaseHandler
|
||||||
config *configs.Config
|
|
||||||
storage *filestore.Storage
|
storage *filestore.Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFileHandler 创建文件处理器
|
// NewFileHandler 创建文件处理器
|
||||||
func NewFileHandler(db *gorm.DB, config *configs.Config, storage *filestore.Storage) *FileHandler {
|
func NewFileHandler(db *gorm.DB, config *configs.Config, storage *filestore.Storage) *FileHandler {
|
||||||
handler := &FileHandler{
|
handler := &FileHandler{
|
||||||
db: db,
|
BaseHandler: NewBaseHandler(db, config),
|
||||||
config: config,
|
storage: storage,
|
||||||
storage: storage,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("FileHandler initialized")
|
handler.LogInit("FileHandler")
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,7 +70,7 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
|
|||||||
|
|
||||||
// 验证文件类型
|
// 验证文件类型
|
||||||
ext := strings.ToLower(filepath.Ext(req.Filename))
|
ext := strings.ToLower(filepath.Ext(req.Filename))
|
||||||
if !isAllowedFileType(ext, req.FileType) {
|
if !h.FileUtil.IsAllowedFileType(ext, req.FileType) {
|
||||||
logger.WarnContext(reqCtx, "GetUploadURL 不支持的文件类型",
|
logger.WarnContext(reqCtx, "GetUploadURL 不支持的文件类型",
|
||||||
"user_id", id,
|
"user_id", id,
|
||||||
"file_type", req.FileType,
|
"file_type", req.FileType,
|
||||||
@@ -82,8 +80,8 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 生成唯一的文件名和存储路径
|
// 生成唯一的文件名和存储路径
|
||||||
fileName := generateFileName(req.Filename)
|
fileName := h.FileUtil.GenerateFileName(req.Filename)
|
||||||
storagePath := getStoragePath(req.FileType, fileName, req.IsPublic)
|
storagePath := h.FileUtil.GetStoragePath(req.FileType, fileName, req.IsPublic)
|
||||||
|
|
||||||
// 生成上传预签名URL
|
// 生成上传预签名URL
|
||||||
uploadURL, err := h.storage.GetUploadPresignedURL(reqCtx, storagePath, req.ContentType, 15) // 15分钟有效期
|
uploadURL, err := h.storage.GetUploadPresignedURL(reqCtx, storagePath, req.ContentType, 15) // 15分钟有效期
|
||||||
@@ -97,7 +95,7 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建文件记录(状态为pending,等待上传完成确认)
|
// 创建文件记录(状态为pending,等待上传完成确认)
|
||||||
fileRepo := model.NewFileRepo(h.db)
|
fileRepo := model.NewFileRepo(h.DB)
|
||||||
fileModel := &model.File{
|
fileModel := &model.File{
|
||||||
Name: fileName,
|
Name: fileName,
|
||||||
OriginalName: req.Filename,
|
OriginalName: req.Filename,
|
||||||
@@ -172,7 +170,7 @@ func (h *FileHandler) ConfirmUpload(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建文件仓库
|
// 创建文件仓库
|
||||||
fileRepo := model.NewFileRepo(h.db)
|
fileRepo := model.NewFileRepo(h.DB)
|
||||||
|
|
||||||
// 获取文件记录
|
// 获取文件记录
|
||||||
fileModel, err := fileRepo.GetByID(reqCtx, req.FileID)
|
fileModel, err := fileRepo.GetByID(reqCtx, req.FileID)
|
||||||
@@ -312,7 +310,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
|
|||||||
size := fileHeader.Size
|
size := fileHeader.Size
|
||||||
|
|
||||||
// 验证文件类型
|
// 验证文件类型
|
||||||
if !isAllowedFileType(ext, fileType) {
|
if !h.FileUtil.IsAllowedFileType(ext, fileType) {
|
||||||
logger.WarnContext(reqCtx, "UploadFile 文件类型不允许",
|
logger.WarnContext(reqCtx, "UploadFile 文件类型不允许",
|
||||||
"user_id", id,
|
"user_id", id,
|
||||||
"file_type", fileType,
|
"file_type", fileType,
|
||||||
@@ -322,7 +320,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证文件大小
|
// 验证文件大小
|
||||||
if !isAllowedFileSize(size, fileType) {
|
if !h.FileUtil.IsAllowedFileSize(size, fileType) {
|
||||||
logger.WarnContext(reqCtx, "UploadFile 文件大小超出限制",
|
logger.WarnContext(reqCtx, "UploadFile 文件大小超出限制",
|
||||||
"user_id", id,
|
"user_id", id,
|
||||||
"file_type", fileType,
|
"file_type", fileType,
|
||||||
@@ -343,8 +341,8 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
|
|||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
// 生成文件名和存储路径
|
// 生成文件名和存储路径
|
||||||
fileName := generateFileName(originalName)
|
fileName := h.FileUtil.GenerateFileName(originalName)
|
||||||
storagePath := getStoragePath(fileType, fileName, isPublic)
|
storagePath := h.FileUtil.GetStoragePath(fileType, fileName, isPublic)
|
||||||
|
|
||||||
// 上传文件到存储(权限由路径和Bucket Policy控制)
|
// 上传文件到存储(权限由路径和Bucket Policy控制)
|
||||||
err = h.storage.Put(reqCtx, storagePath, file)
|
err = h.storage.Put(reqCtx, storagePath, file)
|
||||||
@@ -369,7 +367,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建文件仓库
|
// 创建文件仓库
|
||||||
fileRepo := model.NewFileRepo(h.db)
|
fileRepo := model.NewFileRepo(h.DB)
|
||||||
|
|
||||||
// 记录到数据库
|
// 记录到数据库
|
||||||
fileModel := &model.File{
|
fileModel := &model.File{
|
||||||
@@ -434,7 +432,7 @@ func (h *FileHandler) GetFileURL(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建文件仓库
|
// 创建文件仓库
|
||||||
fileRepo := model.NewFileRepo(h.db)
|
fileRepo := model.NewFileRepo(h.DB)
|
||||||
|
|
||||||
// 获取文件信息
|
// 获取文件信息
|
||||||
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
||||||
@@ -569,7 +567,7 @@ func (h *FileHandler) ServePublicFile(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建文件仓库
|
// 创建文件仓库
|
||||||
fileRepo := model.NewFileRepo(h.db)
|
fileRepo := model.NewFileRepo(h.DB)
|
||||||
|
|
||||||
// 获取文件信息
|
// 获取文件信息
|
||||||
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
||||||
@@ -632,7 +630,7 @@ func (h *FileHandler) ListFiles(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建文件仓库
|
// 创建文件仓库
|
||||||
fileRepo := model.NewFileRepo(h.db)
|
fileRepo := model.NewFileRepo(h.DB)
|
||||||
|
|
||||||
// 获取文件列表
|
// 获取文件列表
|
||||||
files, total, err := fileRepo.ListFiles(reqCtx, page, pageSize, fileType, usage, nil)
|
files, total, err := fileRepo.ListFiles(reqCtx, page, pageSize, fileType, usage, nil)
|
||||||
@@ -678,7 +676,7 @@ func (h *FileHandler) GetFileInfo(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建文件仓库
|
// 创建文件仓库
|
||||||
fileRepo := model.NewFileRepo(h.db)
|
fileRepo := model.NewFileRepo(h.DB)
|
||||||
|
|
||||||
// 获取文件信息
|
// 获取文件信息
|
||||||
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
||||||
@@ -696,92 +694,3 @@ func (h *FileHandler) GetFileInfo(ctx *gin.Context) {
|
|||||||
|
|
||||||
response.Success(ctx, fileModel)
|
response.Success(ctx, fileModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成文件名
|
|
||||||
func generateFileName(originalName string) string {
|
|
||||||
ext := filepath.Ext(originalName)
|
|
||||||
name := fmt.Sprintf("%d_%s%s", time.Now().UnixNano(), randString(8), ext)
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
// 随机字符串
|
|
||||||
func randString(n int) string {
|
|
||||||
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
b := make([]byte, n)
|
|
||||||
for i := range b {
|
|
||||||
b[i] = letters[int(time.Now().UnixNano()%int64(len(letters)))]
|
|
||||||
time.Sleep(1 * time.Nanosecond) // 确保随机性
|
|
||||||
}
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取存储路径
|
|
||||||
func getStoragePath(fileType, fileName string, isPublic bool) string {
|
|
||||||
// 根据是否公开选择根目录
|
|
||||||
var rootDir string
|
|
||||||
if isPublic {
|
|
||||||
rootDir = "public"
|
|
||||||
} else {
|
|
||||||
rootDir = "private"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据文件类型选择子目录
|
|
||||||
var typeDir string
|
|
||||||
switch fileType {
|
|
||||||
case model.FileTypeImage:
|
|
||||||
typeDir = "images"
|
|
||||||
case model.FileTypeDocument:
|
|
||||||
typeDir = "documents"
|
|
||||||
case model.FileTypeVideo:
|
|
||||||
typeDir = "videos"
|
|
||||||
case model.FileTypeAudio:
|
|
||||||
typeDir = "audios"
|
|
||||||
default:
|
|
||||||
typeDir = "others"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加日期子目录
|
|
||||||
dateDir := time.Now().Format("2006/01/02")
|
|
||||||
|
|
||||||
// 构建完整路径:public/documents/2025/06/17/filename.txt
|
|
||||||
path := filepath.Join(rootDir, typeDir, dateDir, fileName)
|
|
||||||
// 确保Windows上也使用正斜杠
|
|
||||||
return strings.ReplaceAll(path, "\\", "/")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查文件类型是否允许
|
|
||||||
func isAllowedFileType(ext string, fileType string) bool {
|
|
||||||
ext = strings.ToLower(ext)
|
|
||||||
switch fileType {
|
|
||||||
case model.FileTypeImage:
|
|
||||||
return ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".gif" || ext == ".webp" || ext == ".svg"
|
|
||||||
case model.FileTypeDocument:
|
|
||||||
return ext == ".pdf" || ext == ".doc" || ext == ".docx" || ext == ".xls" || ext == ".xlsx" || ext == ".txt"
|
|
||||||
case model.FileTypeVideo:
|
|
||||||
return ext == ".mp4" || ext == ".avi" || ext == ".mov" || ext == ".wmv" || ext == ".flv"
|
|
||||||
case model.FileTypeAudio:
|
|
||||||
return ext == ".mp3" || ext == ".wav" || ext == ".ogg" || ext == ".flac"
|
|
||||||
default:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查文件大小是否允许
|
|
||||||
func isAllowedFileSize(size int64, fileType string) bool {
|
|
||||||
const (
|
|
||||||
MB = 1024 * 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
switch fileType {
|
|
||||||
case model.FileTypeImage:
|
|
||||||
return size <= 10*MB // 图片最大10MB
|
|
||||||
case model.FileTypeDocument:
|
|
||||||
return size <= 50*MB // 文档最大50MB
|
|
||||||
case model.FileTypeVideo:
|
|
||||||
return size <= 500*MB // 视频最大500MB
|
|
||||||
case model.FileTypeAudio:
|
|
||||||
return size <= 100*MB // 音频最大100MB
|
|
||||||
default:
|
|
||||||
return size <= 50*MB // 其他类型最大50MB
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
191
internal/handler/file_util.go
Normal file
191
internal/handler/file_util.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/limitcool/starter/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileUtil 文件工具类
|
||||||
|
type FileUtil struct {
|
||||||
|
baseURL string // 基础URL,如 "/uploads" 或 "https://cdn.example.com"
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileUtil 创建文件工具实例
|
||||||
|
func NewFileUtil(baseURL string) *FileUtil {
|
||||||
|
return &FileUtil{
|
||||||
|
baseURL: baseURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildFileURL 构建文件访问URL
|
||||||
|
func (f *FileUtil) BuildFileURL(path string) string {
|
||||||
|
if path == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保路径以正斜杠开头
|
||||||
|
if !strings.HasPrefix(path, "/") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果baseURL已经是完整URL,直接拼接
|
||||||
|
if strings.HasPrefix(f.baseURL, "http") {
|
||||||
|
return f.baseURL + path
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则作为相对路径处理
|
||||||
|
return f.baseURL + path
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateFileName 生成唯一的文件名
|
||||||
|
func (f *FileUtil) GenerateFileName(originalName string) string {
|
||||||
|
ext := filepath.Ext(originalName)
|
||||||
|
name := fmt.Sprintf("%d_%s%s", time.Now().UnixNano(), randString(8), ext)
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStoragePath 获取存储路径
|
||||||
|
func (f *FileUtil) GetStoragePath(fileType, fileName string, isPublic bool) string {
|
||||||
|
// 根据是否公开选择根目录
|
||||||
|
var rootDir string
|
||||||
|
if isPublic {
|
||||||
|
rootDir = "public"
|
||||||
|
} else {
|
||||||
|
rootDir = "private"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据文件类型选择子目录
|
||||||
|
var typeDir string
|
||||||
|
switch fileType {
|
||||||
|
case model.FileTypeImage:
|
||||||
|
typeDir = "images"
|
||||||
|
case model.FileTypeDocument:
|
||||||
|
typeDir = "documents"
|
||||||
|
case model.FileTypeVideo:
|
||||||
|
typeDir = "videos"
|
||||||
|
case model.FileTypeAudio:
|
||||||
|
typeDir = "audios"
|
||||||
|
default:
|
||||||
|
typeDir = "others"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加日期子目录
|
||||||
|
dateDir := time.Now().Format("2006/01/02")
|
||||||
|
|
||||||
|
// 构建完整路径:public/documents/2025/06/17/filename.txt
|
||||||
|
path := filepath.Join(rootDir, typeDir, dateDir, fileName)
|
||||||
|
// 确保Windows上也使用正斜杠
|
||||||
|
return strings.ReplaceAll(path, "\\", "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFileURL 为文件模型设置URL字段
|
||||||
|
func (f *FileUtil) SetFileURL(file *model.File) {
|
||||||
|
if file != nil && file.Path != "" {
|
||||||
|
file.URL = f.BuildFileURL(file.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFileURLs 为文件列表设置URL字段
|
||||||
|
func (f *FileUtil) SetFileURLs(files []model.File) {
|
||||||
|
for i := range files {
|
||||||
|
f.SetFileURL(&files[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAllowedFileType 检查文件类型是否允许
|
||||||
|
func (f *FileUtil) IsAllowedFileType(ext, fileType string) bool {
|
||||||
|
ext = strings.ToLower(ext)
|
||||||
|
|
||||||
|
switch fileType {
|
||||||
|
case model.FileTypeImage:
|
||||||
|
return isImageExt(ext)
|
||||||
|
case model.FileTypeDocument:
|
||||||
|
return isDocumentExt(ext)
|
||||||
|
case model.FileTypeVideo:
|
||||||
|
return isVideoExt(ext)
|
||||||
|
case model.FileTypeAudio:
|
||||||
|
return isAudioExt(ext)
|
||||||
|
default:
|
||||||
|
return true // 其他类型默认允许
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAllowedFileSize 检查文件大小是否允许
|
||||||
|
func (f *FileUtil) IsAllowedFileSize(size int64, fileType string) bool {
|
||||||
|
const (
|
||||||
|
MB = 1024 * 1024
|
||||||
|
GB = 1024 * MB
|
||||||
|
)
|
||||||
|
|
||||||
|
switch fileType {
|
||||||
|
case model.FileTypeImage:
|
||||||
|
return size <= 10*MB // 图片最大10MB
|
||||||
|
case model.FileTypeDocument:
|
||||||
|
return size <= 50*MB // 文档最大50MB
|
||||||
|
case model.FileTypeVideo:
|
||||||
|
return size <= 500*MB // 视频最大500MB
|
||||||
|
case model.FileTypeAudio:
|
||||||
|
return size <= 100*MB // 音频最大100MB
|
||||||
|
default:
|
||||||
|
return size <= 100*MB // 其他类型最大100MB
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// randString 生成随机字符串
|
||||||
|
func randString(n int) string {
|
||||||
|
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
b := make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = letters[int(time.Now().UnixNano()%int64(len(letters)))]
|
||||||
|
time.Sleep(1 * time.Nanosecond) // 确保随机性
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isImageExt 检查是否为图片扩展名
|
||||||
|
func isImageExt(ext string) bool {
|
||||||
|
imageExts := []string{".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".svg"}
|
||||||
|
for _, allowedExt := range imageExts {
|
||||||
|
if ext == allowedExt {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDocumentExt 检查是否为文档扩展名
|
||||||
|
func isDocumentExt(ext string) bool {
|
||||||
|
docExts := []string{".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx", ".txt", ".rtf"}
|
||||||
|
for _, allowedExt := range docExts {
|
||||||
|
if ext == allowedExt {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isVideoExt 检查是否为视频扩展名
|
||||||
|
func isVideoExt(ext string) bool {
|
||||||
|
videoExts := []string{".mp4", ".avi", ".mov", ".wmv", ".flv", ".webm", ".mkv"}
|
||||||
|
for _, allowedExt := range videoExts {
|
||||||
|
if ext == allowedExt {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAudioExt 检查是否为音频扩展名
|
||||||
|
func isAudioExt(ext string) bool {
|
||||||
|
audioExts := []string{".mp3", ".wav", ".flac", ".aac", ".ogg", ".wma"}
|
||||||
|
for _, allowedExt := range audioExts {
|
||||||
|
if ext == allowedExt {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
157
internal/handler/helper.go
Normal file
157
internal/handler/helper.go
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/limitcool/starter/internal/api/response"
|
||||||
|
"github.com/limitcool/starter/internal/pkg/errorx"
|
||||||
|
"github.com/limitcool/starter/internal/pkg/logger"
|
||||||
|
"github.com/spf13/cast"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HandlerHelper 处理器辅助工具
|
||||||
|
type HandlerHelper struct{}
|
||||||
|
|
||||||
|
// NewHandlerHelper 创建处理器辅助工具
|
||||||
|
func NewHandlerHelper() *HandlerHelper {
|
||||||
|
return &HandlerHelper{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserID 从上下文中获取用户ID,如果不存在则返回错误响应
|
||||||
|
func (h *HandlerHelper) GetUserID(ctx *gin.Context) (int64, bool) {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
|
||||||
|
userID, exists := ctx.Get("user_id")
|
||||||
|
if !exists {
|
||||||
|
logger.WarnContext(reqCtx, "用户ID不存在")
|
||||||
|
response.Error(ctx, errorx.ErrUserNoLogin)
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return cast.ToInt64(userID), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindJSON 绑定JSON参数,如果失败则返回错误响应
|
||||||
|
func (h *HandlerHelper) BindJSON(ctx *gin.Context, req interface{}, operation string) bool {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
|
||||||
|
if err := ctx.ShouldBindJSON(req); err != nil {
|
||||||
|
logger.WarnContext(reqCtx, operation+" request validation failed",
|
||||||
|
"error", err,
|
||||||
|
"client_ip", ctx.ClientIP())
|
||||||
|
response.Error(ctx, errorx.ErrInvalidParams.WithError(err))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleDBError 处理数据库错误,统一日志记录和错误响应
|
||||||
|
func (h *HandlerHelper) HandleDBError(ctx *gin.Context, err error, operation string, fields ...interface{}) {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
|
||||||
|
// 构建日志字段
|
||||||
|
logFields := []interface{}{
|
||||||
|
"error", err,
|
||||||
|
"operation", operation,
|
||||||
|
}
|
||||||
|
logFields = append(logFields, fields...)
|
||||||
|
|
||||||
|
logger.ErrorContext(reqCtx, operation+" database operation failed", logFields...)
|
||||||
|
response.Error(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleNotFoundError 处理资源不存在错误
|
||||||
|
func (h *HandlerHelper) HandleNotFoundError(ctx *gin.Context, err error, operation string, fields ...interface{}) {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
|
||||||
|
// 构建日志字段
|
||||||
|
logFields := []interface{}{
|
||||||
|
"operation", operation,
|
||||||
|
}
|
||||||
|
logFields = append(logFields, fields...)
|
||||||
|
|
||||||
|
logger.WarnContext(reqCtx, operation+" resource not found", logFields...)
|
||||||
|
response.Error(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogSuccess 记录成功操作日志
|
||||||
|
func (h *HandlerHelper) LogSuccess(ctx *gin.Context, operation string, fields ...interface{}) {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
|
||||||
|
// 构建日志字段
|
||||||
|
logFields := []interface{}{
|
||||||
|
"operation", operation,
|
||||||
|
}
|
||||||
|
logFields = append(logFields, fields...)
|
||||||
|
|
||||||
|
logger.InfoContext(reqCtx, operation+" successful", logFields...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogWarning 记录警告日志
|
||||||
|
func (h *HandlerHelper) LogWarning(ctx *gin.Context, message string, fields ...interface{}) {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
logger.WarnContext(reqCtx, message, fields...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogError 记录错误日志
|
||||||
|
func (h *HandlerHelper) LogError(ctx *gin.Context, message string, fields ...interface{}) {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
logger.ErrorContext(reqCtx, message, fields...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckPermission 检查用户权限(是否为管理员或资源所有者)
|
||||||
|
func (h *HandlerHelper) CheckPermission(ctx *gin.Context, userID int64, resourceOwnerID int64, operation string) bool {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
|
||||||
|
// 如果是资源所有者,直接允许
|
||||||
|
if userID == resourceOwnerID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是管理员
|
||||||
|
isAdmin, exists := ctx.Get("is_admin")
|
||||||
|
if exists && cast.ToBool(isAdmin) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 权限不足
|
||||||
|
logger.WarnContext(reqCtx, operation+" permission denied",
|
||||||
|
"user_id", userID,
|
||||||
|
"resource_owner_id", resourceOwnerID,
|
||||||
|
"is_admin", isAdmin)
|
||||||
|
response.Error(ctx, errorx.ErrForbidden.WithMsg("无权限操作此资源"))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientInfo 获取客户端信息
|
||||||
|
func (h *HandlerHelper) GetClientInfo(ctx *gin.Context) (string, string) {
|
||||||
|
return ctx.ClientIP(), ctx.Request.UserAgent()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateID 验证ID参数
|
||||||
|
func (h *HandlerHelper) ValidateID(ctx *gin.Context, idStr string, operation string) (uint, bool) {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
|
||||||
|
id := cast.ToUint(idStr)
|
||||||
|
if id == 0 {
|
||||||
|
logger.WarnContext(reqCtx, operation+" invalid ID", "id", idStr)
|
||||||
|
response.Error(ctx, errorx.ErrInvalidParams.WithMsg("ID无效"))
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateInt64ID 验证int64类型的ID参数
|
||||||
|
func (h *HandlerHelper) ValidateInt64ID(ctx *gin.Context, idStr string, operation string) (int64, bool) {
|
||||||
|
reqCtx := ctx.Request.Context()
|
||||||
|
|
||||||
|
id := cast.ToInt64(idStr)
|
||||||
|
if id == 0 {
|
||||||
|
logger.WarnContext(reqCtx, operation+" invalid ID", "id", idStr)
|
||||||
|
response.Error(ctx, errorx.ErrInvalidParams.WithMsg("ID无效"))
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, true
|
||||||
|
}
|
@@ -12,26 +12,23 @@ import (
|
|||||||
"github.com/limitcool/starter/internal/pkg/crypto"
|
"github.com/limitcool/starter/internal/pkg/crypto"
|
||||||
"github.com/limitcool/starter/internal/pkg/errorx"
|
"github.com/limitcool/starter/internal/pkg/errorx"
|
||||||
"github.com/limitcool/starter/internal/pkg/logger"
|
"github.com/limitcool/starter/internal/pkg/logger"
|
||||||
"github.com/spf13/cast"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserHandler 用户处理器
|
// UserHandler 用户处理器
|
||||||
type UserHandler struct {
|
type UserHandler struct {
|
||||||
db *gorm.DB
|
*BaseHandler
|
||||||
config *configs.Config
|
|
||||||
authService *AuthService
|
authService *AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserHandler 创建用户处理器
|
// NewUserHandler 创建用户处理器
|
||||||
func NewUserHandler(db *gorm.DB, config *configs.Config) *UserHandler {
|
func NewUserHandler(db *gorm.DB, config *configs.Config) *UserHandler {
|
||||||
handler := &UserHandler{
|
handler := &UserHandler{
|
||||||
db: db,
|
BaseHandler: NewBaseHandler(db, config),
|
||||||
config: config,
|
|
||||||
authService: NewAuthService(config),
|
authService: NewAuthService(config),
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("UserHandler initialized")
|
handler.LogInit("UserHandler")
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +61,7 @@ func (h *UserHandler) UserLogin(ctx *gin.Context) {
|
|||||||
"ip", clientIP)
|
"ip", clientIP)
|
||||||
|
|
||||||
// 创建用户仓库
|
// 创建用户仓库
|
||||||
userRepo := model.NewUserRepo(h.db)
|
userRepo := model.NewUserRepo(h.DB)
|
||||||
|
|
||||||
// 查询用户
|
// 查询用户
|
||||||
user, err := userRepo.GetByUsername(reqCtx, req.Username)
|
user, err := userRepo.GetByUsername(reqCtx, req.Username)
|
||||||
@@ -161,7 +158,7 @@ func (h *UserHandler) UserRegister(ctx *gin.Context) {
|
|||||||
clientIP := ctx.ClientIP()
|
clientIP := ctx.ClientIP()
|
||||||
|
|
||||||
// 创建用户仓库
|
// 创建用户仓库
|
||||||
userRepo := model.NewUserRepo(h.db)
|
userRepo := model.NewUserRepo(h.DB)
|
||||||
|
|
||||||
// 检查用户名是否已存在
|
// 检查用户名是否已存在
|
||||||
exists, err := userRepo.IsExist(reqCtx, req.Username)
|
exists, err := userRepo.IsExist(reqCtx, req.Username)
|
||||||
@@ -230,97 +227,64 @@ func (h *UserHandler) UserRegister(ctx *gin.Context) {
|
|||||||
|
|
||||||
// UserInfo 获取用户信息
|
// UserInfo 获取用户信息
|
||||||
func (h *UserHandler) UserInfo(ctx *gin.Context) {
|
func (h *UserHandler) UserInfo(ctx *gin.Context) {
|
||||||
// 获取请求上下文
|
// 获取用户ID
|
||||||
reqCtx := ctx.Request.Context()
|
id, ok := h.Helper.GetUserID(ctx)
|
||||||
|
if !ok {
|
||||||
// 从上下文中获取用户ID
|
|
||||||
userID, exists := ctx.Get("user_id")
|
|
||||||
if !exists {
|
|
||||||
logger.WarnContext(reqCtx, "UserInfo user ID not found")
|
|
||||||
response.Error(ctx, errorx.ErrUserNoLogin)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换用户ID
|
|
||||||
id := cast.ToInt64(userID)
|
|
||||||
|
|
||||||
// 创建用户仓库
|
// 创建用户仓库
|
||||||
userRepo := model.NewUserRepo(h.db)
|
userRepo := model.NewUserRepo(h.DB)
|
||||||
|
|
||||||
// 查询用户信息
|
// 查询用户信息
|
||||||
user, err := userRepo.GetByID(reqCtx, id)
|
user, err := userRepo.GetByID(ctx.Request.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errorx.ErrUserNotFound) {
|
if errors.Is(err, errorx.ErrUserNotFound) {
|
||||||
logger.WarnContext(reqCtx, "UserInfo user not found",
|
h.Helper.HandleNotFoundError(ctx, err, "UserInfo", "user_id", id)
|
||||||
"user_id", id)
|
|
||||||
response.Error(ctx, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.ErrorContext(reqCtx, "UserInfo failed to query user",
|
h.Helper.HandleDBError(ctx, err, "UserInfo", "user_id", id)
|
||||||
"error", err,
|
|
||||||
"user_id", id)
|
|
||||||
response.Error(ctx, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 隐藏敏感信息
|
// 隐藏敏感信息
|
||||||
user.Password = ""
|
user.Password = ""
|
||||||
|
|
||||||
logger.InfoContext(reqCtx, "UserInfo get user info successful",
|
h.Helper.LogSuccess(ctx, "UserInfo", "user_id", id)
|
||||||
"user_id", id)
|
|
||||||
|
|
||||||
response.Success(ctx, user)
|
response.Success(ctx, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserChangePassword 修改密码
|
// UserChangePassword 修改密码
|
||||||
func (h *UserHandler) UserChangePassword(ctx *gin.Context) {
|
func (h *UserHandler) UserChangePassword(ctx *gin.Context) {
|
||||||
// 获取请求上下文
|
// 获取用户ID
|
||||||
reqCtx := ctx.Request.Context()
|
id, ok := h.Helper.GetUserID(ctx)
|
||||||
|
if !ok {
|
||||||
// 从上下文中获取用户ID
|
|
||||||
userID, exists := ctx.Get("user_id")
|
|
||||||
if !exists {
|
|
||||||
logger.WarnContext(reqCtx, "UserChangePassword user ID not found")
|
|
||||||
response.Error(ctx, errorx.ErrUserNoLogin)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换用户ID
|
|
||||||
id := cast.ToInt64(userID)
|
|
||||||
|
|
||||||
// 绑定请求参数
|
// 绑定请求参数
|
||||||
var req v1.UserChangePasswordRequest
|
var req v1.UserChangePasswordRequest
|
||||||
if err := ctx.ShouldBindJSON(&req); err != nil {
|
if !h.Helper.BindJSON(ctx, &req, "UserChangePassword") {
|
||||||
logger.WarnContext(reqCtx, "UserChangePassword request validation failed",
|
|
||||||
"error", err,
|
|
||||||
"user_id", id)
|
|
||||||
response.Error(ctx, errorx.ErrInvalidParams.WithError(err))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建用户仓库
|
// 创建用户仓库
|
||||||
userRepo := model.NewUserRepo(h.db)
|
userRepo := model.NewUserRepo(h.DB)
|
||||||
|
|
||||||
// 查询用户
|
// 查询用户
|
||||||
user, err := userRepo.GetByID(reqCtx, id)
|
user, err := userRepo.GetByID(ctx.Request.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errorx.ErrUserNotFound) {
|
if errors.Is(err, errorx.ErrUserNotFound) {
|
||||||
logger.WarnContext(reqCtx, "UserChangePassword user not found",
|
h.Helper.HandleNotFoundError(ctx, err, "UserChangePassword", "user_id", id)
|
||||||
"user_id", id)
|
|
||||||
response.Error(ctx, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.ErrorContext(reqCtx, "UserChangePassword failed to query user",
|
h.Helper.HandleDBError(ctx, err, "UserChangePassword", "user_id", id)
|
||||||
"error", err,
|
|
||||||
"user_id", id)
|
|
||||||
response.Error(ctx, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证旧密码
|
// 验证旧密码
|
||||||
if !crypto.CheckPassword(user.Password, req.OldPassword) {
|
if !crypto.CheckPassword(user.Password, req.OldPassword) {
|
||||||
logger.WarnContext(reqCtx, "UserChangePassword old password incorrect",
|
h.Helper.LogWarning(ctx, "UserChangePassword old password incorrect", "user_id", id)
|
||||||
"user_id", id)
|
|
||||||
response.Error(ctx, errorx.Errorf(errorx.ErrUserPasswordError, "旧密码错误"))
|
response.Error(ctx, errorx.Errorf(errorx.ErrUserPasswordError, "旧密码错误"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -328,24 +292,17 @@ func (h *UserHandler) UserChangePassword(ctx *gin.Context) {
|
|||||||
// 哈希新密码
|
// 哈希新密码
|
||||||
hashedPassword, err := crypto.HashPassword(req.NewPassword)
|
hashedPassword, err := crypto.HashPassword(req.NewPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.ErrorContext(reqCtx, "UserChangePassword failed to hash password",
|
h.Helper.LogError(ctx, "UserChangePassword failed to hash password", "error", err, "user_id", id)
|
||||||
"error", err,
|
|
||||||
"user_id", id)
|
|
||||||
response.Error(ctx, errorx.WrapError(err, "密码加密失败"))
|
response.Error(ctx, errorx.WrapError(err, "密码加密失败"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新密码
|
// 更新密码
|
||||||
if err := userRepo.UpdatePassword(reqCtx, id, hashedPassword); err != nil {
|
if err := userRepo.UpdatePassword(ctx.Request.Context(), id, hashedPassword); err != nil {
|
||||||
logger.ErrorContext(reqCtx, "UserChangePassword failed to update password",
|
h.Helper.HandleDBError(ctx, err, "UserChangePassword", "user_id", id)
|
||||||
"error", err,
|
|
||||||
"user_id", id)
|
|
||||||
response.Error(ctx, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.InfoContext(reqCtx, "UserChangePassword password change successful",
|
h.Helper.LogSuccess(ctx, "UserChangePassword", "user_id", id)
|
||||||
"user_id", id)
|
|
||||||
|
|
||||||
response.SuccessNoData(ctx, "密码修改成功")
|
response.SuccessNoData(ctx, "密码修改成功")
|
||||||
}
|
}
|
||||||
|
@@ -11,29 +11,9 @@ import (
|
|||||||
// AdminCheck 管理员检查中间件 - 基于JWT中的is_admin字段
|
// AdminCheck 管理员检查中间件 - 基于JWT中的is_admin字段
|
||||||
func AdminCheck() gin.HandlerFunc {
|
func AdminCheck() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 获取请求上下文
|
if !CheckAdminPermission(c) {
|
||||||
ctx := c.Request.Context()
|
|
||||||
|
|
||||||
// 从上下文中获取用户ID
|
|
||||||
_, exists := c.Get("user_id")
|
|
||||||
if !exists {
|
|
||||||
logger.WarnContext(ctx, "AdminCheck 未找到用户ID")
|
|
||||||
response.Error(c, errorx.ErrUserNoLogin)
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查用户是否为管理员
|
|
||||||
isAdmin, ok := c.Get("is_admin")
|
|
||||||
if !ok || !isAdmin.(bool) {
|
|
||||||
logger.WarnContext(ctx, "AdminCheck 用户不是管理员",
|
|
||||||
"is_admin", isAdmin)
|
|
||||||
response.Error(c, errorx.ErrUserNoLogin.WithMsg("用户无权限"))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 继续处理请求
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,37 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RequestContext 中间件,同时处理请求ID和链路追踪ID
|
|
||||||
func RequestContext() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
// 处理请求ID
|
|
||||||
requestID := c.GetHeader("X-Request-ID")
|
|
||||||
if requestID == "" {
|
|
||||||
requestID = fmt.Sprintf("req-%d", time.Now().UnixNano())
|
|
||||||
}
|
|
||||||
c.Set("request_id", requestID)
|
|
||||||
c.Header("X-Request-ID", requestID)
|
|
||||||
|
|
||||||
// 处理链路追踪ID
|
|
||||||
traceID := c.GetHeader("X-Trace-ID")
|
|
||||||
if traceID == "" {
|
|
||||||
traceID = fmt.Sprintf("trace-%d", time.Now().UnixNano())
|
|
||||||
}
|
|
||||||
c.Set("trace_id", traceID)
|
|
||||||
c.Header("X-Trace-ID", traceID)
|
|
||||||
|
|
||||||
// 将请求ID和链路追踪ID添加到context.Context中
|
|
||||||
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
|
|
||||||
ctx = context.WithValue(ctx, "trace_id", traceID)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
@@ -2,36 +2,45 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/limitcool/starter/internal/pkg/logger"
|
"github.com/limitcool/starter/internal/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestLoggerMiddleware 是一个记录请求日志的中间件
|
// RequestLoggerMiddleware 是一个记录请求日志的中间件,同时处理请求ID和链路追踪ID
|
||||||
func RequestLoggerMiddleware() gin.HandlerFunc {
|
func RequestLoggerMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 记录开始时间
|
// 记录开始时间
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
// 获取请求ID
|
// 处理请求ID
|
||||||
requestID := c.GetHeader("X-Request-ID")
|
requestID := c.GetHeader("X-Request-ID")
|
||||||
if requestID == "" {
|
if requestID == "" {
|
||||||
requestID = uuid.New().String()
|
requestID = fmt.Sprintf("req-%d", time.Now().UnixNano())
|
||||||
c.Request.Header.Set("X-Request-ID", requestID)
|
|
||||||
c.Set("request_id", requestID) // 同时存入上下文
|
|
||||||
|
|
||||||
// 将请求ID添加到context.Context中
|
|
||||||
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
|
||||||
}
|
}
|
||||||
|
c.Set("request_id", requestID)
|
||||||
|
c.Header("X-Request-ID", requestID)
|
||||||
|
|
||||||
|
// 处理链路追踪ID
|
||||||
|
traceID := c.GetHeader("X-Trace-ID")
|
||||||
|
if traceID == "" {
|
||||||
|
traceID = fmt.Sprintf("trace-%d", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
c.Set("trace_id", traceID)
|
||||||
|
c.Header("X-Trace-ID", traceID)
|
||||||
|
|
||||||
|
// 将请求ID和链路追踪ID添加到context.Context中
|
||||||
|
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
|
||||||
|
ctx = context.WithValue(ctx, "trace_id", traceID)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
|
||||||
// 处理请求
|
// 处理请求
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|
||||||
// 获取请求上下文
|
// 获取请求上下文
|
||||||
ctx := c.Request.Context()
|
reqCtx := c.Request.Context()
|
||||||
|
|
||||||
// 计算延迟
|
// 计算延迟
|
||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
@@ -59,11 +68,11 @@ func RequestLoggerMiddleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
// 根据状态码选择日志级别
|
// 根据状态码选择日志级别
|
||||||
if status >= 500 {
|
if status >= 500 {
|
||||||
logger.ErrorContext(ctx, "Server error", fields...)
|
logger.ErrorContext(reqCtx, "Server error", fields...)
|
||||||
} else if status >= 400 {
|
} else if status >= 400 {
|
||||||
logger.WarnContext(ctx, "Client error", fields...)
|
logger.WarnContext(reqCtx, "Client error", fields...)
|
||||||
} else {
|
} else {
|
||||||
logger.InfoContext(ctx, "Request completed", fields...)
|
logger.InfoContext(reqCtx, "Request completed", fields...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -11,19 +11,9 @@ import (
|
|||||||
// UserCheck 用户检查中间件 - 确保用户已登录
|
// UserCheck 用户检查中间件 - 确保用户已登录
|
||||||
func UserCheck() gin.HandlerFunc {
|
func UserCheck() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 获取请求上下文
|
if !CheckUserLogin(c) {
|
||||||
ctx := c.Request.Context()
|
|
||||||
|
|
||||||
// 从上下文获取用户ID
|
|
||||||
_, exists := c.Get("user_id")
|
|
||||||
if !exists {
|
|
||||||
logger.WarnContext(ctx, "用户ID不存在")
|
|
||||||
response.Error(c, errorx.ErrUserNoLogin)
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 继续处理请求
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -62,28 +52,21 @@ func UserCheckWithDB(userRepo *model.UserRepo) gin.HandlerFunc {
|
|||||||
// 适用于只允许普通用户访问的接口
|
// 适用于只允许普通用户访问的接口
|
||||||
func RegularUserCheck() gin.HandlerFunc {
|
func RegularUserCheck() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 获取请求上下文
|
// 先检查是否已登录
|
||||||
ctx := c.Request.Context()
|
if !CheckUserLogin(c) {
|
||||||
|
|
||||||
// 从上下文获取用户ID
|
|
||||||
_, exists := c.Get("user_id")
|
|
||||||
if !exists {
|
|
||||||
logger.WarnContext(ctx, "用户ID不存在")
|
|
||||||
response.Error(c, errorx.ErrUserNoLogin)
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查用户是否为管理员
|
// 检查用户是否为管理员
|
||||||
isAdmin, ok := c.Get("is_admin")
|
isAdmin, ok := c.Get("is_admin")
|
||||||
if ok && isAdmin.(bool) {
|
if ok && isAdmin.(bool) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
logger.WarnContext(ctx, "管理员不能访问普通用户接口")
|
logger.WarnContext(ctx, "管理员不能访问普通用户接口")
|
||||||
response.Error(c, errorx.ErrAccessDenied)
|
response.Error(c, errorx.ErrAccessDenied)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 继续处理请求
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -4,6 +4,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/limitcool/starter/internal/api/response"
|
||||||
|
"github.com/limitcool/starter/internal/pkg/errorx"
|
||||||
|
"github.com/limitcool/starter/internal/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetUserID 从上下文中获取用户ID
|
// GetUserID 从上下文中获取用户ID
|
||||||
@@ -64,3 +67,39 @@ func GetUserIDString(c *gin.Context) string {
|
|||||||
}
|
}
|
||||||
return fmt.Sprintf("%d", id)
|
return fmt.Sprintf("%d", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckUserLogin 检查用户是否已登录,如果未登录则返回错误响应
|
||||||
|
func CheckUserLogin(c *gin.Context) bool {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
_, exists := c.Get("user_id")
|
||||||
|
if !exists {
|
||||||
|
logger.WarnContext(ctx, "用户ID不存在")
|
||||||
|
response.Error(c, errorx.ErrUserNoLogin)
|
||||||
|
c.Abort()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckAdminPermission 检查用户是否为管理员,如果不是则返回错误响应
|
||||||
|
func CheckAdminPermission(c *gin.Context) bool {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
// 先检查是否已登录
|
||||||
|
if !CheckUserLogin(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查用户是否为管理员
|
||||||
|
isAdmin, ok := c.Get("is_admin")
|
||||||
|
if !ok || !isAdmin.(bool) {
|
||||||
|
logger.WarnContext(ctx, "用户不是管理员", "is_admin", isAdmin)
|
||||||
|
response.Error(c, errorx.ErrUserNoLogin.WithMsg("用户无权限"))
|
||||||
|
c.Abort()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
@@ -54,8 +54,13 @@ func (File) TableName() string {
|
|||||||
|
|
||||||
// FileRepo 文件仓库
|
// FileRepo 文件仓库
|
||||||
type FileRepo struct {
|
type FileRepo struct {
|
||||||
DB *gorm.DB
|
*GenericRepo[File]
|
||||||
GenericRepo *GenericRepo[File]
|
fileUtil FileURLBuilder // 文件URL构建器接口
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileURLBuilder 文件URL构建器接口
|
||||||
|
type FileURLBuilder interface {
|
||||||
|
BuildFileURL(path string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFileRepo 创建文件仓库
|
// NewFileRepo 创建文件仓库
|
||||||
@@ -64,52 +69,51 @@ func NewFileRepo(db *gorm.DB) *FileRepo {
|
|||||||
genericRepo.ErrorCode = errorx.ErrorFileNotFoundCode
|
genericRepo.ErrorCode = errorx.ErrorFileNotFoundCode
|
||||||
|
|
||||||
return &FileRepo{
|
return &FileRepo{
|
||||||
DB: db,
|
|
||||||
GenericRepo: genericRepo,
|
GenericRepo: genericRepo,
|
||||||
|
fileUtil: &defaultFileURLBuilder{}, // 使用默认实现
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create 创建文件记录
|
// NewFileRepoWithURLBuilder 创建带有自定义URL构建器的文件仓库
|
||||||
func (r *FileRepo) Create(ctx context.Context, file *File) error {
|
func NewFileRepoWithURLBuilder(db *gorm.DB, urlBuilder FileURLBuilder) *FileRepo {
|
||||||
return r.GenericRepo.Create(ctx, file)
|
genericRepo := NewGenericRepo[File](db)
|
||||||
|
genericRepo.ErrorCode = errorx.ErrorFileNotFoundCode
|
||||||
|
|
||||||
|
return &FileRepo{
|
||||||
|
GenericRepo: genericRepo,
|
||||||
|
fileUtil: urlBuilder,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultFileURLBuilder 默认的文件URL构建器
|
||||||
|
type defaultFileURLBuilder struct{}
|
||||||
|
|
||||||
|
func (d *defaultFileURLBuilder) BuildFileURL(path string) string {
|
||||||
|
if path == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "/uploads/" + path
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID 根据ID获取文件
|
// GetByID 根据ID获取文件
|
||||||
func (r *FileRepo) GetByID(ctx context.Context, id uint) (*File, error) {
|
func (r *FileRepo) GetByID(ctx context.Context, id uint) (*File, error) {
|
||||||
file, err := r.GenericRepo.Get(ctx, id, nil)
|
file, err := r.Get(ctx, id, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.WrapError(err, "查询文件失败")
|
return nil, errorx.WrapError(err, "查询文件失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置URL字段
|
// 设置URL字段
|
||||||
if file.Path != "" {
|
if file.Path != "" {
|
||||||
file.URL = r.buildFileURL(file.Path)
|
file.URL = r.fileUtil.BuildFileURL(file.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildFileURL 构建文件URL
|
|
||||||
func (r *FileRepo) buildFileURL(path string) string {
|
|
||||||
// 这里可以根据实际情况构建URL,例如添加域名前缀等
|
|
||||||
// 简单示例:
|
|
||||||
return "/uploads/" + path
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update 更新文件记录
|
|
||||||
func (r *FileRepo) Update(ctx context.Context, file *File) error {
|
|
||||||
return r.GenericRepo.Update(ctx, file)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 删除文件记录
|
|
||||||
func (r *FileRepo) Delete(ctx context.Context, id uint) error {
|
|
||||||
return r.GenericRepo.Delete(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateFileUsage 更新文件用途
|
// UpdateFileUsage 更新文件用途
|
||||||
func (r *FileRepo) UpdateFileUsage(ctx context.Context, file *File, usage string) error {
|
func (r *FileRepo) UpdateFileUsage(ctx context.Context, file *File, usage string) error {
|
||||||
file.Usage = usage
|
file.Usage = usage
|
||||||
return r.GenericRepo.Update(ctx, file)
|
return r.Update(ctx, file)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListByUser 获取用户的文件列表
|
// ListByUser 获取用户的文件列表
|
||||||
@@ -119,7 +123,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
|
|||||||
Condition: "uploaded_by = ? AND uploaded_by_type = ?",
|
Condition: "uploaded_by = ? AND uploaded_by_type = ?",
|
||||||
Args: []any{userID, 2},
|
Args: []any{userID, 2},
|
||||||
}
|
}
|
||||||
files, err := r.GenericRepo.List(ctx, page, pageSize, opts)
|
files, err := r.List(ctx, page, pageSize, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.WrapError(err, "查询用户文件列表失败")
|
return nil, errorx.WrapError(err, "查询用户文件列表失败")
|
||||||
}
|
}
|
||||||
@@ -127,7 +131,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
|
|||||||
// 为所有文件设置URL
|
// 为所有文件设置URL
|
||||||
for i := range files {
|
for i := range files {
|
||||||
if files[i].Path != "" {
|
if files[i].Path != "" {
|
||||||
files[i].URL = r.buildFileURL(files[i].Path)
|
files[i].URL = r.fileUtil.BuildFileURL(files[i].Path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,7 +140,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
|
|||||||
|
|
||||||
// CountByUser 获取用户的文件总数
|
// CountByUser 获取用户的文件总数
|
||||||
func (r *FileRepo) CountByUser(ctx context.Context, userID int64) (int64, error) {
|
func (r *FileRepo) CountByUser(ctx context.Context, userID int64) (int64, error) {
|
||||||
count, err := r.GenericRepo.Count(ctx, &QueryOptions{
|
count, err := r.Count(ctx, &QueryOptions{
|
||||||
Condition: "uploaded_by = ? AND uploaded_by_type = ?",
|
Condition: "uploaded_by = ? AND uploaded_by_type = ?",
|
||||||
Args: []any{userID, 2},
|
Args: []any{userID, 2},
|
||||||
})
|
})
|
||||||
@@ -180,7 +184,7 @@ func (r *FileRepo) ListFiles(ctx context.Context, page, pageSize int, fileType,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取文件列表
|
// 获取文件列表
|
||||||
files, err := r.GenericRepo.List(ctx, page, pageSize, opts)
|
files, err := r.List(ctx, page, pageSize, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, errorx.WrapError(err, "查询文件列表失败")
|
return nil, 0, errorx.WrapError(err, "查询文件列表失败")
|
||||||
}
|
}
|
||||||
@@ -188,12 +192,12 @@ func (r *FileRepo) ListFiles(ctx context.Context, page, pageSize int, fileType,
|
|||||||
// 为所有文件设置URL
|
// 为所有文件设置URL
|
||||||
for i := range files {
|
for i := range files {
|
||||||
if files[i].Path != "" {
|
if files[i].Path != "" {
|
||||||
files[i].URL = r.buildFileURL(files[i].Path)
|
files[i].URL = r.fileUtil.BuildFileURL(files[i].Path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取总数
|
// 获取总数
|
||||||
total, err := r.GenericRepo.Count(ctx, opts)
|
total, err := r.Count(ctx, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, errorx.WrapError(err, "查询文件总数失败")
|
return nil, 0, errorx.WrapError(err, "查询文件总数失败")
|
||||||
}
|
}
|
||||||
|
@@ -47,8 +47,7 @@ func NewUser() *User {
|
|||||||
|
|
||||||
// UserRepo 用户仓库
|
// UserRepo 用户仓库
|
||||||
type UserRepo struct {
|
type UserRepo struct {
|
||||||
DB *gorm.DB
|
*GenericRepo[User]
|
||||||
GenericRepo *GenericRepo[User]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserRepo 创建用户仓库
|
// NewUserRepo 创建用户仓库
|
||||||
@@ -57,14 +56,13 @@ func NewUserRepo(db *gorm.DB) *UserRepo {
|
|||||||
genericRepo.ErrorCode = errorx.ErrorUserNotFoundCode
|
genericRepo.ErrorCode = errorx.ErrorUserNotFoundCode
|
||||||
|
|
||||||
return &UserRepo{
|
return &UserRepo{
|
||||||
DB: db,
|
|
||||||
GenericRepo: genericRepo,
|
GenericRepo: genericRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID 根据ID获取用户
|
// GetByID 根据ID获取用户
|
||||||
func (r *UserRepo) GetByID(ctx context.Context, id int64) (*User, error) {
|
func (r *UserRepo) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||||
user, err := r.GenericRepo.Get(ctx, id, nil)
|
user, err := r.Get(ctx, id, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.WrapError(err, "查询用户失败")
|
return nil, errorx.WrapError(err, "查询用户失败")
|
||||||
}
|
}
|
||||||
@@ -81,7 +79,7 @@ func (r *UserRepo) GetUserWithAvatar(ctx context.Context, id int64) (*User, erro
|
|||||||
|
|
||||||
// 如果用户有头像,再预加载头像
|
// 如果用户有头像,再预加载头像
|
||||||
if user.AvatarFileID > 0 {
|
if user.AvatarFileID > 0 {
|
||||||
user, err = r.GenericRepo.Get(ctx, id, &QueryOptions{
|
user, err = r.Get(ctx, id, &QueryOptions{
|
||||||
Preloads: []string{"AvatarFile"},
|
Preloads: []string{"AvatarFile"},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -99,7 +97,7 @@ func (r *UserRepo) GetUserWithAvatar(ctx context.Context, id int64) (*User, erro
|
|||||||
|
|
||||||
// GetByUsername 根据用户名获取用户
|
// GetByUsername 根据用户名获取用户
|
||||||
func (r *UserRepo) GetByUsername(ctx context.Context, username string) (*User, error) {
|
func (r *UserRepo) GetByUsername(ctx context.Context, username string) (*User, error) {
|
||||||
user, err := r.GenericRepo.Get(ctx, nil, &QueryOptions{
|
user, err := r.Get(ctx, nil, &QueryOptions{
|
||||||
Condition: "username = ?",
|
Condition: "username = ?",
|
||||||
Args: []any{username},
|
Args: []any{username},
|
||||||
})
|
})
|
||||||
@@ -112,24 +110,9 @@ func (r *UserRepo) GetByUsername(ctx context.Context, username string) (*User, e
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create 创建用户
|
|
||||||
func (r *UserRepo) Create(ctx context.Context, user *User) error {
|
|
||||||
return r.GenericRepo.Create(ctx, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update 更新用户
|
|
||||||
func (r *UserRepo) Update(ctx context.Context, user *User) error {
|
|
||||||
return r.GenericRepo.Update(ctx, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 删除用户
|
|
||||||
func (r *UserRepo) Delete(ctx context.Context, id int64) error {
|
|
||||||
return r.GenericRepo.Delete(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsExist 检查用户是否存在
|
// IsExist 检查用户是否存在
|
||||||
func (r *UserRepo) IsExist(ctx context.Context, username string) (bool, error) {
|
func (r *UserRepo) IsExist(ctx context.Context, username string) (bool, error) {
|
||||||
count, err := r.GenericRepo.Count(ctx, &QueryOptions{
|
count, err := r.Count(ctx, &QueryOptions{
|
||||||
Condition: "username = ?",
|
Condition: "username = ?",
|
||||||
Args: []any{username},
|
Args: []any{username},
|
||||||
})
|
})
|
||||||
@@ -157,7 +140,7 @@ func (r *UserRepo) ListUsers(ctx context.Context, page, pageSize int, keyword st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取用户列表
|
// 获取用户列表
|
||||||
users, err := r.GenericRepo.List(ctx, page, pageSize, opts)
|
users, err := r.List(ctx, page, pageSize, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, errorx.WrapError(err, "查询用户列表失败")
|
return nil, 0, errorx.WrapError(err, "查询用户列表失败")
|
||||||
}
|
}
|
||||||
@@ -170,7 +153,7 @@ func (r *UserRepo) ListUsers(ctx context.Context, page, pageSize int, keyword st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取总数
|
// 获取总数
|
||||||
total, err := r.GenericRepo.Count(ctx, opts)
|
total, err := r.Count(ctx, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, errorx.WrapError(err, "查询用户总数失败")
|
return nil, 0, errorx.WrapError(err, "查询用户总数失败")
|
||||||
}
|
}
|
||||||
|
@@ -1,47 +0,0 @@
|
|||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 运行所有测试
|
|
||||||
func main() {
|
|
||||||
// 获取项目根目录
|
|
||||||
rootDir, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("获取当前目录失败: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 切换到项目根目录
|
|
||||||
rootDir = filepath.Dir(rootDir)
|
|
||||||
if err := os.Chdir(rootDir); err != nil {
|
|
||||||
fmt.Printf("切换到项目根目录失败: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 运行单元测试
|
|
||||||
fmt.Println("运行单元测试...")
|
|
||||||
cmd := exec.Command("go", "test", "./internal/pkg/casbin", "./internal/middleware", "-v")
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
fmt.Printf("单元测试失败: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 运行集成测试
|
|
||||||
fmt.Println("\n运行集成测试...")
|
|
||||||
cmd = exec.Command("go", "test", "./test/integration", "-v")
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
fmt.Printf("集成测试失败: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("\n所有测试通过!")
|
|
||||||
}
|
|
Reference in New Issue
Block a user