Refactor user handler and middleware for improved error handling and logging

- Consolidated user ID retrieval and permission checks into helper functions.
- Updated UserHandler to utilize BaseHandler for common database and configuration access.
- Enhanced logging for user-related operations, including login, registration, and password changes.
- Removed redundant context handling in middleware and improved readability.
- Introduced FileUtil for file URL generation and management, encapsulating file-related logic.
- Refactored FileRepo and UserRepo to streamline database operations and error handling.
- Deleted unused request_id middleware and integrated its functionality into request_logger.
- Removed legacy test runner script to simplify testing process.
This commit is contained in:
limitcool
2025-06-17 23:09:02 +08:00
parent 36d816d908
commit b7628c770b
16 changed files with 588 additions and 407 deletions

View File

@@ -86,7 +86,11 @@ version:
.PHONY: build-dev
build-dev:
@echo "Building $(APP_NAME) for development..."
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME)$(if $(filter windows,$(shell go env GOOS)),.exe,)
ifeq ($(OS),Windows_NT)
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME).exe
else
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME)
endif
# Docker 相关目标
.PHONY: docker-build

View File

@@ -42,44 +42,65 @@ type Handlers struct {
Admin *handler.AdminHandler
}
// InitStep 初始化步骤
type InitStep struct {
Name string
Required bool
Init func() error
}
// New 创建新的应用实例
func New(config *configs.Config) (*App, error) {
app := &App{
config: config,
}
app := &App{config: config}
// 按顺序初始化各个组件
if err := app.initDatabase(); err != nil {
return nil, fmt.Errorf("failed to initialize database: %w", err)
}
// 定义初始化步骤
steps := app.getInitSteps()
if err := app.initRedis(); err != nil {
return nil, fmt.Errorf("failed to initialize redis: %w", err)
}
// 按顺序执行初始化
for _, step := range steps {
logger.Info("Initializing component", "component", step.Name)
if err := app.initStorage(); err != nil {
return nil, fmt.Errorf("failed to initialize storage: %w", err)
}
if err := step.Init(); err != nil {
logger.Error("Failed to initialize component",
"component", step.Name,
"required", step.Required,
"error", err)
if err := app.initHandlers(); err != nil {
return nil, fmt.Errorf("failed to initialize handlers: %w", err)
}
if step.Required {
return nil, fmt.Errorf("failed to initialize required component %s: %w", step.Name, err)
}
if err := app.initRouter(); err != nil {
return nil, fmt.Errorf("failed to initialize router: %w", err)
}
logger.Warn("Optional component initialization failed, continuing",
"component", step.Name)
continue
}
if err := app.initServer(); err != nil {
return nil, fmt.Errorf("failed to initialize server: %w", err)
}
if err := app.initPprof(); err != nil {
return nil, fmt.Errorf("failed to initialize pprof: %w", err)
logger.Info("Component initialized successfully", "component", step.Name)
}
return app, nil
}
// getInitSteps 获取初始化步骤列表
func (app *App) getInitSteps() []InitStep {
steps := []InitStep{
// 数据库和Redis根据配置启用失败时不影响应用启动内部有禁用检查
{Name: "database", Required: false, Init: app.initDatabase},
{Name: "redis", Required: false, Init: app.initRedis},
// 存储服务是可选的,某些功能可能需要它
{Name: "storage", Required: false, Init: app.initStorage},
// 核心组件,必须成功初始化
{Name: "handlers", Required: true, Init: app.initHandlers},
{Name: "router", Required: true, Init: app.initRouter},
{Name: "server", Required: true, Init: app.initServer},
{Name: "pprof", Required: false, Init: app.initPprof},
}
return steps
}
// initDatabase 初始化数据库连接
func (a *App) initDatabase() error {
if !a.config.Database.Enabled {

View File

@@ -10,18 +10,16 @@ import (
// AdminHandler 管理员处理器
type AdminHandler struct {
db *gorm.DB
config *configs.Config
*BaseHandler
}
// NewAdminHandler 创建管理员处理器
func NewAdminHandler(db *gorm.DB, config *configs.Config) *AdminHandler {
handler := &AdminHandler{
db: db,
config: config,
BaseHandler: NewBaseHandler(db, config),
}
logger.Info("AdminHandler initialized")
handler.LogInit("AdminHandler")
return handler
}
@@ -35,9 +33,9 @@ func (h *AdminHandler) GetSystemSettings(ctx *gin.Context) {
// 返回系统设置
settings := map[string]any{
"app_name": h.config.App.Name,
"app_name": h.Config.App.Name,
"app_version": "1.0.0",
"app_mode": h.config.App.Mode,
"app_mode": h.Config.App.Mode,
}
response.Success(ctx, settings)

View File

@@ -0,0 +1,30 @@
package handler
import (
"github.com/limitcool/starter/configs"
"github.com/limitcool/starter/internal/pkg/logger"
"gorm.io/gorm"
)
// BaseHandler 基础处理器包含所有Handler的公共字段和方法
type BaseHandler struct {
DB *gorm.DB
Config *configs.Config
Helper *HandlerHelper
FileUtil *FileUtil
}
// NewBaseHandler 创建基础处理器
func NewBaseHandler(db *gorm.DB, config *configs.Config) *BaseHandler {
return &BaseHandler{
DB: db,
Config: config,
Helper: NewHandlerHelper(),
FileUtil: NewFileUtil("/uploads"), // 默认基础URL
}
}
// LogInit 记录Handler初始化日志
func (h *BaseHandler) LogInit(handlerName string) {
logger.Info(handlerName + " initialized")
}

View File

@@ -19,20 +19,18 @@ import (
// FileHandler 文件处理器
type FileHandler struct {
db *gorm.DB
config *configs.Config
*BaseHandler
storage *filestore.Storage
}
// NewFileHandler 创建文件处理器
func NewFileHandler(db *gorm.DB, config *configs.Config, storage *filestore.Storage) *FileHandler {
handler := &FileHandler{
db: db,
config: config,
storage: storage,
BaseHandler: NewBaseHandler(db, config),
storage: storage,
}
logger.Info("FileHandler initialized")
handler.LogInit("FileHandler")
return handler
}
@@ -72,7 +70,7 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
// 验证文件类型
ext := strings.ToLower(filepath.Ext(req.Filename))
if !isAllowedFileType(ext, req.FileType) {
if !h.FileUtil.IsAllowedFileType(ext, req.FileType) {
logger.WarnContext(reqCtx, "GetUploadURL 不支持的文件类型",
"user_id", id,
"file_type", req.FileType,
@@ -82,8 +80,8 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
}
// 生成唯一的文件名和存储路径
fileName := generateFileName(req.Filename)
storagePath := getStoragePath(req.FileType, fileName, req.IsPublic)
fileName := h.FileUtil.GenerateFileName(req.Filename)
storagePath := h.FileUtil.GetStoragePath(req.FileType, fileName, req.IsPublic)
// 生成上传预签名URL
uploadURL, err := h.storage.GetUploadPresignedURL(reqCtx, storagePath, req.ContentType, 15) // 15分钟有效期
@@ -97,7 +95,7 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
}
// 创建文件记录状态为pending等待上传完成确认
fileRepo := model.NewFileRepo(h.db)
fileRepo := model.NewFileRepo(h.DB)
fileModel := &model.File{
Name: fileName,
OriginalName: req.Filename,
@@ -172,7 +170,7 @@ func (h *FileHandler) ConfirmUpload(ctx *gin.Context) {
}
// 创建文件仓库
fileRepo := model.NewFileRepo(h.db)
fileRepo := model.NewFileRepo(h.DB)
// 获取文件记录
fileModel, err := fileRepo.GetByID(reqCtx, req.FileID)
@@ -312,7 +310,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
size := fileHeader.Size
// 验证文件类型
if !isAllowedFileType(ext, fileType) {
if !h.FileUtil.IsAllowedFileType(ext, fileType) {
logger.WarnContext(reqCtx, "UploadFile 文件类型不允许",
"user_id", id,
"file_type", fileType,
@@ -322,7 +320,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
}
// 验证文件大小
if !isAllowedFileSize(size, fileType) {
if !h.FileUtil.IsAllowedFileSize(size, fileType) {
logger.WarnContext(reqCtx, "UploadFile 文件大小超出限制",
"user_id", id,
"file_type", fileType,
@@ -343,8 +341,8 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
defer file.Close()
// 生成文件名和存储路径
fileName := generateFileName(originalName)
storagePath := getStoragePath(fileType, fileName, isPublic)
fileName := h.FileUtil.GenerateFileName(originalName)
storagePath := h.FileUtil.GetStoragePath(fileType, fileName, isPublic)
// 上传文件到存储权限由路径和Bucket Policy控制
err = h.storage.Put(reqCtx, storagePath, file)
@@ -369,7 +367,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
}
// 创建文件仓库
fileRepo := model.NewFileRepo(h.db)
fileRepo := model.NewFileRepo(h.DB)
// 记录到数据库
fileModel := &model.File{
@@ -434,7 +432,7 @@ func (h *FileHandler) GetFileURL(ctx *gin.Context) {
}
// 创建文件仓库
fileRepo := model.NewFileRepo(h.db)
fileRepo := model.NewFileRepo(h.DB)
// 获取文件信息
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
@@ -569,7 +567,7 @@ func (h *FileHandler) ServePublicFile(ctx *gin.Context) {
}
// 创建文件仓库
fileRepo := model.NewFileRepo(h.db)
fileRepo := model.NewFileRepo(h.DB)
// 获取文件信息
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
@@ -632,7 +630,7 @@ func (h *FileHandler) ListFiles(ctx *gin.Context) {
}
// 创建文件仓库
fileRepo := model.NewFileRepo(h.db)
fileRepo := model.NewFileRepo(h.DB)
// 获取文件列表
files, total, err := fileRepo.ListFiles(reqCtx, page, pageSize, fileType, usage, nil)
@@ -678,7 +676,7 @@ func (h *FileHandler) GetFileInfo(ctx *gin.Context) {
}
// 创建文件仓库
fileRepo := model.NewFileRepo(h.db)
fileRepo := model.NewFileRepo(h.DB)
// 获取文件信息
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
@@ -696,92 +694,3 @@ func (h *FileHandler) GetFileInfo(ctx *gin.Context) {
response.Success(ctx, fileModel)
}
// 生成文件名
func generateFileName(originalName string) string {
ext := filepath.Ext(originalName)
name := fmt.Sprintf("%d_%s%s", time.Now().UnixNano(), randString(8), ext)
return name
}
// 随机字符串
func randString(n int) string {
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, n)
for i := range b {
b[i] = letters[int(time.Now().UnixNano()%int64(len(letters)))]
time.Sleep(1 * time.Nanosecond) // 确保随机性
}
return string(b)
}
// 获取存储路径
func getStoragePath(fileType, fileName string, isPublic bool) string {
// 根据是否公开选择根目录
var rootDir string
if isPublic {
rootDir = "public"
} else {
rootDir = "private"
}
// 根据文件类型选择子目录
var typeDir string
switch fileType {
case model.FileTypeImage:
typeDir = "images"
case model.FileTypeDocument:
typeDir = "documents"
case model.FileTypeVideo:
typeDir = "videos"
case model.FileTypeAudio:
typeDir = "audios"
default:
typeDir = "others"
}
// 添加日期子目录
dateDir := time.Now().Format("2006/01/02")
// 构建完整路径public/documents/2025/06/17/filename.txt
path := filepath.Join(rootDir, typeDir, dateDir, fileName)
// 确保Windows上也使用正斜杠
return strings.ReplaceAll(path, "\\", "/")
}
// 检查文件类型是否允许
func isAllowedFileType(ext string, fileType string) bool {
ext = strings.ToLower(ext)
switch fileType {
case model.FileTypeImage:
return ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".gif" || ext == ".webp" || ext == ".svg"
case model.FileTypeDocument:
return ext == ".pdf" || ext == ".doc" || ext == ".docx" || ext == ".xls" || ext == ".xlsx" || ext == ".txt"
case model.FileTypeVideo:
return ext == ".mp4" || ext == ".avi" || ext == ".mov" || ext == ".wmv" || ext == ".flv"
case model.FileTypeAudio:
return ext == ".mp3" || ext == ".wav" || ext == ".ogg" || ext == ".flac"
default:
return true
}
}
// 检查文件大小是否允许
func isAllowedFileSize(size int64, fileType string) bool {
const (
MB = 1024 * 1024
)
switch fileType {
case model.FileTypeImage:
return size <= 10*MB // 图片最大10MB
case model.FileTypeDocument:
return size <= 50*MB // 文档最大50MB
case model.FileTypeVideo:
return size <= 500*MB // 视频最大500MB
case model.FileTypeAudio:
return size <= 100*MB // 音频最大100MB
default:
return size <= 50*MB // 其他类型最大50MB
}
}

View File

@@ -0,0 +1,191 @@
package handler
import (
"fmt"
"path/filepath"
"strings"
"time"
"github.com/limitcool/starter/internal/model"
)
// FileUtil 文件工具类
type FileUtil struct {
baseURL string // 基础URL如 "/uploads" 或 "https://cdn.example.com"
}
// NewFileUtil 创建文件工具实例
func NewFileUtil(baseURL string) *FileUtil {
return &FileUtil{
baseURL: baseURL,
}
}
// BuildFileURL 构建文件访问URL
func (f *FileUtil) BuildFileURL(path string) string {
if path == "" {
return ""
}
// 确保路径以正斜杠开头
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
// 如果baseURL已经是完整URL直接拼接
if strings.HasPrefix(f.baseURL, "http") {
return f.baseURL + path
}
// 否则作为相对路径处理
return f.baseURL + path
}
// GenerateFileName 生成唯一的文件名
func (f *FileUtil) GenerateFileName(originalName string) string {
ext := filepath.Ext(originalName)
name := fmt.Sprintf("%d_%s%s", time.Now().UnixNano(), randString(8), ext)
return name
}
// GetStoragePath 获取存储路径
func (f *FileUtil) GetStoragePath(fileType, fileName string, isPublic bool) string {
// 根据是否公开选择根目录
var rootDir string
if isPublic {
rootDir = "public"
} else {
rootDir = "private"
}
// 根据文件类型选择子目录
var typeDir string
switch fileType {
case model.FileTypeImage:
typeDir = "images"
case model.FileTypeDocument:
typeDir = "documents"
case model.FileTypeVideo:
typeDir = "videos"
case model.FileTypeAudio:
typeDir = "audios"
default:
typeDir = "others"
}
// 添加日期子目录
dateDir := time.Now().Format("2006/01/02")
// 构建完整路径public/documents/2025/06/17/filename.txt
path := filepath.Join(rootDir, typeDir, dateDir, fileName)
// 确保Windows上也使用正斜杠
return strings.ReplaceAll(path, "\\", "/")
}
// SetFileURL 为文件模型设置URL字段
func (f *FileUtil) SetFileURL(file *model.File) {
if file != nil && file.Path != "" {
file.URL = f.BuildFileURL(file.Path)
}
}
// SetFileURLs 为文件列表设置URL字段
func (f *FileUtil) SetFileURLs(files []model.File) {
for i := range files {
f.SetFileURL(&files[i])
}
}
// IsAllowedFileType 检查文件类型是否允许
func (f *FileUtil) IsAllowedFileType(ext, fileType string) bool {
ext = strings.ToLower(ext)
switch fileType {
case model.FileTypeImage:
return isImageExt(ext)
case model.FileTypeDocument:
return isDocumentExt(ext)
case model.FileTypeVideo:
return isVideoExt(ext)
case model.FileTypeAudio:
return isAudioExt(ext)
default:
return true // 其他类型默认允许
}
}
// IsAllowedFileSize 检查文件大小是否允许
func (f *FileUtil) IsAllowedFileSize(size int64, fileType string) bool {
const (
MB = 1024 * 1024
GB = 1024 * MB
)
switch fileType {
case model.FileTypeImage:
return size <= 10*MB // 图片最大10MB
case model.FileTypeDocument:
return size <= 50*MB // 文档最大50MB
case model.FileTypeVideo:
return size <= 500*MB // 视频最大500MB
case model.FileTypeAudio:
return size <= 100*MB // 音频最大100MB
default:
return size <= 100*MB // 其他类型最大100MB
}
}
// randString 生成随机字符串
func randString(n int) string {
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, n)
for i := range b {
b[i] = letters[int(time.Now().UnixNano()%int64(len(letters)))]
time.Sleep(1 * time.Nanosecond) // 确保随机性
}
return string(b)
}
// isImageExt 检查是否为图片扩展名
func isImageExt(ext string) bool {
imageExts := []string{".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".svg"}
for _, allowedExt := range imageExts {
if ext == allowedExt {
return true
}
}
return false
}
// isDocumentExt 检查是否为文档扩展名
func isDocumentExt(ext string) bool {
docExts := []string{".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx", ".txt", ".rtf"}
for _, allowedExt := range docExts {
if ext == allowedExt {
return true
}
}
return false
}
// isVideoExt 检查是否为视频扩展名
func isVideoExt(ext string) bool {
videoExts := []string{".mp4", ".avi", ".mov", ".wmv", ".flv", ".webm", ".mkv"}
for _, allowedExt := range videoExts {
if ext == allowedExt {
return true
}
}
return false
}
// isAudioExt 检查是否为音频扩展名
func isAudioExt(ext string) bool {
audioExts := []string{".mp3", ".wav", ".flac", ".aac", ".ogg", ".wma"}
for _, allowedExt := range audioExts {
if ext == allowedExt {
return true
}
}
return false
}

157
internal/handler/helper.go Normal file
View File

@@ -0,0 +1,157 @@
package handler
import (
"github.com/gin-gonic/gin"
"github.com/limitcool/starter/internal/api/response"
"github.com/limitcool/starter/internal/pkg/errorx"
"github.com/limitcool/starter/internal/pkg/logger"
"github.com/spf13/cast"
)
// HandlerHelper 处理器辅助工具
type HandlerHelper struct{}
// NewHandlerHelper 创建处理器辅助工具
func NewHandlerHelper() *HandlerHelper {
return &HandlerHelper{}
}
// GetUserID 从上下文中获取用户ID如果不存在则返回错误响应
func (h *HandlerHelper) GetUserID(ctx *gin.Context) (int64, bool) {
reqCtx := ctx.Request.Context()
userID, exists := ctx.Get("user_id")
if !exists {
logger.WarnContext(reqCtx, "用户ID不存在")
response.Error(ctx, errorx.ErrUserNoLogin)
return 0, false
}
return cast.ToInt64(userID), true
}
// BindJSON 绑定JSON参数如果失败则返回错误响应
func (h *HandlerHelper) BindJSON(ctx *gin.Context, req interface{}, operation string) bool {
reqCtx := ctx.Request.Context()
if err := ctx.ShouldBindJSON(req); err != nil {
logger.WarnContext(reqCtx, operation+" request validation failed",
"error", err,
"client_ip", ctx.ClientIP())
response.Error(ctx, errorx.ErrInvalidParams.WithError(err))
return false
}
return true
}
// HandleDBError 处理数据库错误,统一日志记录和错误响应
func (h *HandlerHelper) HandleDBError(ctx *gin.Context, err error, operation string, fields ...interface{}) {
reqCtx := ctx.Request.Context()
// 构建日志字段
logFields := []interface{}{
"error", err,
"operation", operation,
}
logFields = append(logFields, fields...)
logger.ErrorContext(reqCtx, operation+" database operation failed", logFields...)
response.Error(ctx, err)
}
// HandleNotFoundError 处理资源不存在错误
func (h *HandlerHelper) HandleNotFoundError(ctx *gin.Context, err error, operation string, fields ...interface{}) {
reqCtx := ctx.Request.Context()
// 构建日志字段
logFields := []interface{}{
"operation", operation,
}
logFields = append(logFields, fields...)
logger.WarnContext(reqCtx, operation+" resource not found", logFields...)
response.Error(ctx, err)
}
// LogSuccess 记录成功操作日志
func (h *HandlerHelper) LogSuccess(ctx *gin.Context, operation string, fields ...interface{}) {
reqCtx := ctx.Request.Context()
// 构建日志字段
logFields := []interface{}{
"operation", operation,
}
logFields = append(logFields, fields...)
logger.InfoContext(reqCtx, operation+" successful", logFields...)
}
// LogWarning 记录警告日志
func (h *HandlerHelper) LogWarning(ctx *gin.Context, message string, fields ...interface{}) {
reqCtx := ctx.Request.Context()
logger.WarnContext(reqCtx, message, fields...)
}
// LogError 记录错误日志
func (h *HandlerHelper) LogError(ctx *gin.Context, message string, fields ...interface{}) {
reqCtx := ctx.Request.Context()
logger.ErrorContext(reqCtx, message, fields...)
}
// CheckPermission 检查用户权限(是否为管理员或资源所有者)
func (h *HandlerHelper) CheckPermission(ctx *gin.Context, userID int64, resourceOwnerID int64, operation string) bool {
reqCtx := ctx.Request.Context()
// 如果是资源所有者,直接允许
if userID == resourceOwnerID {
return true
}
// 检查是否是管理员
isAdmin, exists := ctx.Get("is_admin")
if exists && cast.ToBool(isAdmin) {
return true
}
// 权限不足
logger.WarnContext(reqCtx, operation+" permission denied",
"user_id", userID,
"resource_owner_id", resourceOwnerID,
"is_admin", isAdmin)
response.Error(ctx, errorx.ErrForbidden.WithMsg("无权限操作此资源"))
return false
}
// GetClientInfo 获取客户端信息
func (h *HandlerHelper) GetClientInfo(ctx *gin.Context) (string, string) {
return ctx.ClientIP(), ctx.Request.UserAgent()
}
// ValidateID 验证ID参数
func (h *HandlerHelper) ValidateID(ctx *gin.Context, idStr string, operation string) (uint, bool) {
reqCtx := ctx.Request.Context()
id := cast.ToUint(idStr)
if id == 0 {
logger.WarnContext(reqCtx, operation+" invalid ID", "id", idStr)
response.Error(ctx, errorx.ErrInvalidParams.WithMsg("ID无效"))
return 0, false
}
return id, true
}
// ValidateInt64ID 验证int64类型的ID参数
func (h *HandlerHelper) ValidateInt64ID(ctx *gin.Context, idStr string, operation string) (int64, bool) {
reqCtx := ctx.Request.Context()
id := cast.ToInt64(idStr)
if id == 0 {
logger.WarnContext(reqCtx, operation+" invalid ID", "id", idStr)
response.Error(ctx, errorx.ErrInvalidParams.WithMsg("ID无效"))
return 0, false
}
return id, true
}

View File

@@ -12,26 +12,23 @@ import (
"github.com/limitcool/starter/internal/pkg/crypto"
"github.com/limitcool/starter/internal/pkg/errorx"
"github.com/limitcool/starter/internal/pkg/logger"
"github.com/spf13/cast"
"gorm.io/gorm"
)
// UserHandler 用户处理器
type UserHandler struct {
db *gorm.DB
config *configs.Config
*BaseHandler
authService *AuthService
}
// NewUserHandler 创建用户处理器
func NewUserHandler(db *gorm.DB, config *configs.Config) *UserHandler {
handler := &UserHandler{
db: db,
config: config,
BaseHandler: NewBaseHandler(db, config),
authService: NewAuthService(config),
}
logger.Info("UserHandler initialized")
handler.LogInit("UserHandler")
return handler
}
@@ -64,7 +61,7 @@ func (h *UserHandler) UserLogin(ctx *gin.Context) {
"ip", clientIP)
// 创建用户仓库
userRepo := model.NewUserRepo(h.db)
userRepo := model.NewUserRepo(h.DB)
// 查询用户
user, err := userRepo.GetByUsername(reqCtx, req.Username)
@@ -161,7 +158,7 @@ func (h *UserHandler) UserRegister(ctx *gin.Context) {
clientIP := ctx.ClientIP()
// 创建用户仓库
userRepo := model.NewUserRepo(h.db)
userRepo := model.NewUserRepo(h.DB)
// 检查用户名是否已存在
exists, err := userRepo.IsExist(reqCtx, req.Username)
@@ -230,97 +227,64 @@ func (h *UserHandler) UserRegister(ctx *gin.Context) {
// UserInfo 获取用户信息
func (h *UserHandler) UserInfo(ctx *gin.Context) {
// 获取请求上下文
reqCtx := ctx.Request.Context()
// 从上下文中获取用户ID
userID, exists := ctx.Get("user_id")
if !exists {
logger.WarnContext(reqCtx, "UserInfo user ID not found")
response.Error(ctx, errorx.ErrUserNoLogin)
// 获取用户ID
id, ok := h.Helper.GetUserID(ctx)
if !ok {
return
}
// 转换用户ID
id := cast.ToInt64(userID)
// 创建用户仓库
userRepo := model.NewUserRepo(h.db)
userRepo := model.NewUserRepo(h.DB)
// 查询用户信息
user, err := userRepo.GetByID(reqCtx, id)
user, err := userRepo.GetByID(ctx.Request.Context(), id)
if err != nil {
if errors.Is(err, errorx.ErrUserNotFound) {
logger.WarnContext(reqCtx, "UserInfo user not found",
"user_id", id)
response.Error(ctx, err)
h.Helper.HandleNotFoundError(ctx, err, "UserInfo", "user_id", id)
return
}
logger.ErrorContext(reqCtx, "UserInfo failed to query user",
"error", err,
"user_id", id)
response.Error(ctx, err)
h.Helper.HandleDBError(ctx, err, "UserInfo", "user_id", id)
return
}
// 隐藏敏感信息
user.Password = ""
logger.InfoContext(reqCtx, "UserInfo get user info successful",
"user_id", id)
h.Helper.LogSuccess(ctx, "UserInfo", "user_id", id)
response.Success(ctx, user)
}
// UserChangePassword 修改密码
func (h *UserHandler) UserChangePassword(ctx *gin.Context) {
// 获取请求上下文
reqCtx := ctx.Request.Context()
// 从上下文中获取用户ID
userID, exists := ctx.Get("user_id")
if !exists {
logger.WarnContext(reqCtx, "UserChangePassword user ID not found")
response.Error(ctx, errorx.ErrUserNoLogin)
// 获取用户ID
id, ok := h.Helper.GetUserID(ctx)
if !ok {
return
}
// 转换用户ID
id := cast.ToInt64(userID)
// 绑定请求参数
var req v1.UserChangePasswordRequest
if err := ctx.ShouldBindJSON(&req); err != nil {
logger.WarnContext(reqCtx, "UserChangePassword request validation failed",
"error", err,
"user_id", id)
response.Error(ctx, errorx.ErrInvalidParams.WithError(err))
if !h.Helper.BindJSON(ctx, &req, "UserChangePassword") {
return
}
// 创建用户仓库
userRepo := model.NewUserRepo(h.db)
userRepo := model.NewUserRepo(h.DB)
// 查询用户
user, err := userRepo.GetByID(reqCtx, id)
user, err := userRepo.GetByID(ctx.Request.Context(), id)
if err != nil {
if errors.Is(err, errorx.ErrUserNotFound) {
logger.WarnContext(reqCtx, "UserChangePassword user not found",
"user_id", id)
response.Error(ctx, err)
h.Helper.HandleNotFoundError(ctx, err, "UserChangePassword", "user_id", id)
return
}
logger.ErrorContext(reqCtx, "UserChangePassword failed to query user",
"error", err,
"user_id", id)
response.Error(ctx, err)
h.Helper.HandleDBError(ctx, err, "UserChangePassword", "user_id", id)
return
}
// 验证旧密码
if !crypto.CheckPassword(user.Password, req.OldPassword) {
logger.WarnContext(reqCtx, "UserChangePassword old password incorrect",
"user_id", id)
h.Helper.LogWarning(ctx, "UserChangePassword old password incorrect", "user_id", id)
response.Error(ctx, errorx.Errorf(errorx.ErrUserPasswordError, "旧密码错误"))
return
}
@@ -328,24 +292,17 @@ func (h *UserHandler) UserChangePassword(ctx *gin.Context) {
// 哈希新密码
hashedPassword, err := crypto.HashPassword(req.NewPassword)
if err != nil {
logger.ErrorContext(reqCtx, "UserChangePassword failed to hash password",
"error", err,
"user_id", id)
h.Helper.LogError(ctx, "UserChangePassword failed to hash password", "error", err, "user_id", id)
response.Error(ctx, errorx.WrapError(err, "密码加密失败"))
return
}
// 更新密码
if err := userRepo.UpdatePassword(reqCtx, id, hashedPassword); err != nil {
logger.ErrorContext(reqCtx, "UserChangePassword failed to update password",
"error", err,
"user_id", id)
response.Error(ctx, err)
if err := userRepo.UpdatePassword(ctx.Request.Context(), id, hashedPassword); err != nil {
h.Helper.HandleDBError(ctx, err, "UserChangePassword", "user_id", id)
return
}
logger.InfoContext(reqCtx, "UserChangePassword password change successful",
"user_id", id)
h.Helper.LogSuccess(ctx, "UserChangePassword", "user_id", id)
response.SuccessNoData(ctx, "密码修改成功")
}

View File

@@ -11,29 +11,9 @@ import (
// AdminCheck 管理员检查中间件 - 基于JWT中的is_admin字段
func AdminCheck() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取请求上下文
ctx := c.Request.Context()
// 从上下文中获取用户ID
_, exists := c.Get("user_id")
if !exists {
logger.WarnContext(ctx, "AdminCheck 未找到用户ID")
response.Error(c, errorx.ErrUserNoLogin)
c.Abort()
if !CheckAdminPermission(c) {
return
}
// 检查用户是否为管理员
isAdmin, ok := c.Get("is_admin")
if !ok || !isAdmin.(bool) {
logger.WarnContext(ctx, "AdminCheck 用户不是管理员",
"is_admin", isAdmin)
response.Error(c, errorx.ErrUserNoLogin.WithMsg("用户无权限"))
c.Abort()
return
}
// 继续处理请求
c.Next()
}
}

View File

@@ -1,37 +0,0 @@
package middleware
import (
"context"
"fmt"
"time"
"github.com/gin-gonic/gin"
)
// RequestContext 中间件同时处理请求ID和链路追踪ID
func RequestContext() gin.HandlerFunc {
return func(c *gin.Context) {
// 处理请求ID
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = fmt.Sprintf("req-%d", time.Now().UnixNano())
}
c.Set("request_id", requestID)
c.Header("X-Request-ID", requestID)
// 处理链路追踪ID
traceID := c.GetHeader("X-Trace-ID")
if traceID == "" {
traceID = fmt.Sprintf("trace-%d", time.Now().UnixNano())
}
c.Set("trace_id", traceID)
c.Header("X-Trace-ID", traceID)
// 将请求ID和链路追踪ID添加到context.Context中
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
ctx = context.WithValue(ctx, "trace_id", traceID)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}

View File

@@ -2,36 +2,45 @@ package middleware
import (
"context"
"fmt"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/limitcool/starter/internal/pkg/logger"
)
// RequestLoggerMiddleware 是一个记录请求日志的中间件
// RequestLoggerMiddleware 是一个记录请求日志的中间件同时处理请求ID和链路追踪ID
func RequestLoggerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 记录开始时间
start := time.Now()
// 获取请求ID
// 处理请求ID
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = uuid.New().String()
c.Request.Header.Set("X-Request-ID", requestID)
c.Set("request_id", requestID) // 同时存入上下文
// 将请求ID添加到context.Context中
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
c.Request = c.Request.WithContext(ctx)
requestID = fmt.Sprintf("req-%d", time.Now().UnixNano())
}
c.Set("request_id", requestID)
c.Header("X-Request-ID", requestID)
// 处理链路追踪ID
traceID := c.GetHeader("X-Trace-ID")
if traceID == "" {
traceID = fmt.Sprintf("trace-%d", time.Now().UnixNano())
}
c.Set("trace_id", traceID)
c.Header("X-Trace-ID", traceID)
// 将请求ID和链路追踪ID添加到context.Context中
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
ctx = context.WithValue(ctx, "trace_id", traceID)
c.Request = c.Request.WithContext(ctx)
// 处理请求
c.Next()
// 获取请求上下文
ctx := c.Request.Context()
reqCtx := c.Request.Context()
// 计算延迟
latency := time.Since(start)
@@ -59,11 +68,11 @@ func RequestLoggerMiddleware() gin.HandlerFunc {
// 根据状态码选择日志级别
if status >= 500 {
logger.ErrorContext(ctx, "Server error", fields...)
logger.ErrorContext(reqCtx, "Server error", fields...)
} else if status >= 400 {
logger.WarnContext(ctx, "Client error", fields...)
logger.WarnContext(reqCtx, "Client error", fields...)
} else {
logger.InfoContext(ctx, "Request completed", fields...)
logger.InfoContext(reqCtx, "Request completed", fields...)
}
}
}

View File

@@ -11,19 +11,9 @@ import (
// UserCheck 用户检查中间件 - 确保用户已登录
func UserCheck() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取请求上下文
ctx := c.Request.Context()
// 从上下文获取用户ID
_, exists := c.Get("user_id")
if !exists {
logger.WarnContext(ctx, "用户ID不存在")
response.Error(c, errorx.ErrUserNoLogin)
c.Abort()
if !CheckUserLogin(c) {
return
}
// 继续处理请求
c.Next()
}
}
@@ -62,28 +52,21 @@ func UserCheckWithDB(userRepo *model.UserRepo) gin.HandlerFunc {
// 适用于只允许普通用户访问的接口
func RegularUserCheck() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取请求上下文
ctx := c.Request.Context()
// 从上下文获取用户ID
_, exists := c.Get("user_id")
if !exists {
logger.WarnContext(ctx, "用户ID不存在")
response.Error(c, errorx.ErrUserNoLogin)
c.Abort()
// 先检查是否已登录
if !CheckUserLogin(c) {
return
}
// 检查用户是否为管理员
isAdmin, ok := c.Get("is_admin")
if ok && isAdmin.(bool) {
ctx := c.Request.Context()
logger.WarnContext(ctx, "管理员不能访问普通用户接口")
response.Error(c, errorx.ErrAccessDenied)
c.Abort()
return
}
// 继续处理请求
c.Next()
}
}

View File

@@ -4,6 +4,9 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/limitcool/starter/internal/api/response"
"github.com/limitcool/starter/internal/pkg/errorx"
"github.com/limitcool/starter/internal/pkg/logger"
)
// GetUserID 从上下文中获取用户ID
@@ -64,3 +67,39 @@ func GetUserIDString(c *gin.Context) string {
}
return fmt.Sprintf("%d", id)
}
// CheckUserLogin 检查用户是否已登录,如果未登录则返回错误响应
func CheckUserLogin(c *gin.Context) bool {
ctx := c.Request.Context()
_, exists := c.Get("user_id")
if !exists {
logger.WarnContext(ctx, "用户ID不存在")
response.Error(c, errorx.ErrUserNoLogin)
c.Abort()
return false
}
return true
}
// CheckAdminPermission 检查用户是否为管理员,如果不是则返回错误响应
func CheckAdminPermission(c *gin.Context) bool {
ctx := c.Request.Context()
// 先检查是否已登录
if !CheckUserLogin(c) {
return false
}
// 检查用户是否为管理员
isAdmin, ok := c.Get("is_admin")
if !ok || !isAdmin.(bool) {
logger.WarnContext(ctx, "用户不是管理员", "is_admin", isAdmin)
response.Error(c, errorx.ErrUserNoLogin.WithMsg("用户无权限"))
c.Abort()
return false
}
return true
}

View File

@@ -54,8 +54,13 @@ func (File) TableName() string {
// FileRepo 文件仓库
type FileRepo struct {
DB *gorm.DB
GenericRepo *GenericRepo[File]
*GenericRepo[File]
fileUtil FileURLBuilder // 文件URL构建器接口
}
// FileURLBuilder 文件URL构建器接口
type FileURLBuilder interface {
BuildFileURL(path string) string
}
// NewFileRepo 创建文件仓库
@@ -64,52 +69,51 @@ func NewFileRepo(db *gorm.DB) *FileRepo {
genericRepo.ErrorCode = errorx.ErrorFileNotFoundCode
return &FileRepo{
DB: db,
GenericRepo: genericRepo,
fileUtil: &defaultFileURLBuilder{}, // 使用默认实现
}
}
// Create 创建文件记录
func (r *FileRepo) Create(ctx context.Context, file *File) error {
return r.GenericRepo.Create(ctx, file)
// NewFileRepoWithURLBuilder 创建带有自定义URL构建器的文件仓库
func NewFileRepoWithURLBuilder(db *gorm.DB, urlBuilder FileURLBuilder) *FileRepo {
genericRepo := NewGenericRepo[File](db)
genericRepo.ErrorCode = errorx.ErrorFileNotFoundCode
return &FileRepo{
GenericRepo: genericRepo,
fileUtil: urlBuilder,
}
}
// defaultFileURLBuilder 默认的文件URL构建器
type defaultFileURLBuilder struct{}
func (d *defaultFileURLBuilder) BuildFileURL(path string) string {
if path == "" {
return ""
}
return "/uploads/" + path
}
// GetByID 根据ID获取文件
func (r *FileRepo) GetByID(ctx context.Context, id uint) (*File, error) {
file, err := r.GenericRepo.Get(ctx, id, nil)
file, err := r.Get(ctx, id, nil)
if err != nil {
return nil, errorx.WrapError(err, "查询文件失败")
}
// 设置URL字段
if file.Path != "" {
file.URL = r.buildFileURL(file.Path)
file.URL = r.fileUtil.BuildFileURL(file.Path)
}
return file, nil
}
// buildFileURL 构建文件URL
func (r *FileRepo) buildFileURL(path string) string {
// 这里可以根据实际情况构建URL例如添加域名前缀等
// 简单示例:
return "/uploads/" + path
}
// Update 更新文件记录
func (r *FileRepo) Update(ctx context.Context, file *File) error {
return r.GenericRepo.Update(ctx, file)
}
// Delete 删除文件记录
func (r *FileRepo) Delete(ctx context.Context, id uint) error {
return r.GenericRepo.Delete(ctx, id)
}
// UpdateFileUsage 更新文件用途
func (r *FileRepo) UpdateFileUsage(ctx context.Context, file *File, usage string) error {
file.Usage = usage
return r.GenericRepo.Update(ctx, file)
return r.Update(ctx, file)
}
// ListByUser 获取用户的文件列表
@@ -119,7 +123,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
Condition: "uploaded_by = ? AND uploaded_by_type = ?",
Args: []any{userID, 2},
}
files, err := r.GenericRepo.List(ctx, page, pageSize, opts)
files, err := r.List(ctx, page, pageSize, opts)
if err != nil {
return nil, errorx.WrapError(err, "查询用户文件列表失败")
}
@@ -127,7 +131,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
// 为所有文件设置URL
for i := range files {
if files[i].Path != "" {
files[i].URL = r.buildFileURL(files[i].Path)
files[i].URL = r.fileUtil.BuildFileURL(files[i].Path)
}
}
@@ -136,7 +140,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
// CountByUser 获取用户的文件总数
func (r *FileRepo) CountByUser(ctx context.Context, userID int64) (int64, error) {
count, err := r.GenericRepo.Count(ctx, &QueryOptions{
count, err := r.Count(ctx, &QueryOptions{
Condition: "uploaded_by = ? AND uploaded_by_type = ?",
Args: []any{userID, 2},
})
@@ -180,7 +184,7 @@ func (r *FileRepo) ListFiles(ctx context.Context, page, pageSize int, fileType,
}
// 获取文件列表
files, err := r.GenericRepo.List(ctx, page, pageSize, opts)
files, err := r.List(ctx, page, pageSize, opts)
if err != nil {
return nil, 0, errorx.WrapError(err, "查询文件列表失败")
}
@@ -188,12 +192,12 @@ func (r *FileRepo) ListFiles(ctx context.Context, page, pageSize int, fileType,
// 为所有文件设置URL
for i := range files {
if files[i].Path != "" {
files[i].URL = r.buildFileURL(files[i].Path)
files[i].URL = r.fileUtil.BuildFileURL(files[i].Path)
}
}
// 获取总数
total, err := r.GenericRepo.Count(ctx, opts)
total, err := r.Count(ctx, opts)
if err != nil {
return nil, 0, errorx.WrapError(err, "查询文件总数失败")
}

View File

@@ -47,8 +47,7 @@ func NewUser() *User {
// UserRepo 用户仓库
type UserRepo struct {
DB *gorm.DB
GenericRepo *GenericRepo[User]
*GenericRepo[User]
}
// NewUserRepo 创建用户仓库
@@ -57,14 +56,13 @@ func NewUserRepo(db *gorm.DB) *UserRepo {
genericRepo.ErrorCode = errorx.ErrorUserNotFoundCode
return &UserRepo{
DB: db,
GenericRepo: genericRepo,
}
}
// GetByID 根据ID获取用户
func (r *UserRepo) GetByID(ctx context.Context, id int64) (*User, error) {
user, err := r.GenericRepo.Get(ctx, id, nil)
user, err := r.Get(ctx, id, nil)
if err != nil {
return nil, errorx.WrapError(err, "查询用户失败")
}
@@ -81,7 +79,7 @@ func (r *UserRepo) GetUserWithAvatar(ctx context.Context, id int64) (*User, erro
// 如果用户有头像,再预加载头像
if user.AvatarFileID > 0 {
user, err = r.GenericRepo.Get(ctx, id, &QueryOptions{
user, err = r.Get(ctx, id, &QueryOptions{
Preloads: []string{"AvatarFile"},
})
if err != nil {
@@ -99,7 +97,7 @@ func (r *UserRepo) GetUserWithAvatar(ctx context.Context, id int64) (*User, erro
// GetByUsername 根据用户名获取用户
func (r *UserRepo) GetByUsername(ctx context.Context, username string) (*User, error) {
user, err := r.GenericRepo.Get(ctx, nil, &QueryOptions{
user, err := r.Get(ctx, nil, &QueryOptions{
Condition: "username = ?",
Args: []any{username},
})
@@ -112,24 +110,9 @@ func (r *UserRepo) GetByUsername(ctx context.Context, username string) (*User, e
return user, nil
}
// Create 创建用户
func (r *UserRepo) Create(ctx context.Context, user *User) error {
return r.GenericRepo.Create(ctx, user)
}
// Update 更新用户
func (r *UserRepo) Update(ctx context.Context, user *User) error {
return r.GenericRepo.Update(ctx, user)
}
// Delete 删除用户
func (r *UserRepo) Delete(ctx context.Context, id int64) error {
return r.GenericRepo.Delete(ctx, id)
}
// IsExist 检查用户是否存在
func (r *UserRepo) IsExist(ctx context.Context, username string) (bool, error) {
count, err := r.GenericRepo.Count(ctx, &QueryOptions{
count, err := r.Count(ctx, &QueryOptions{
Condition: "username = ?",
Args: []any{username},
})
@@ -157,7 +140,7 @@ func (r *UserRepo) ListUsers(ctx context.Context, page, pageSize int, keyword st
}
// 获取用户列表
users, err := r.GenericRepo.List(ctx, page, pageSize, opts)
users, err := r.List(ctx, page, pageSize, opts)
if err != nil {
return nil, 0, errorx.WrapError(err, "查询用户列表失败")
}
@@ -170,7 +153,7 @@ func (r *UserRepo) ListUsers(ctx context.Context, page, pageSize int, keyword st
}
// 获取总数
total, err := r.GenericRepo.Count(ctx, opts)
total, err := r.Count(ctx, opts)
if err != nil {
return nil, 0, errorx.WrapError(err, "查询用户总数失败")
}

View File

@@ -1,47 +0,0 @@
package test
import (
"fmt"
"os"
"os/exec"
"path/filepath"
)
// 运行所有测试
func main() {
// 获取项目根目录
rootDir, err := os.Getwd()
if err != nil {
fmt.Printf("获取当前目录失败: %v\n", err)
os.Exit(1)
}
// 切换到项目根目录
rootDir = filepath.Dir(rootDir)
if err := os.Chdir(rootDir); err != nil {
fmt.Printf("切换到项目根目录失败: %v\n", err)
os.Exit(1)
}
// 运行单元测试
fmt.Println("运行单元测试...")
cmd := exec.Command("go", "test", "./internal/pkg/casbin", "./internal/middleware", "-v")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
fmt.Printf("单元测试失败: %v\n", err)
os.Exit(1)
}
// 运行集成测试
fmt.Println("\n运行集成测试...")
cmd = exec.Command("go", "test", "./test/integration", "-v")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
fmt.Printf("集成测试失败: %v\n", err)
os.Exit(1)
}
fmt.Println("\n所有测试通过!")
}