mirror of
https://github.com/limitcool/starter.git
synced 2025-09-26 20:31:25 +08:00
Refactor user handler and middleware for improved error handling and logging
- Consolidated user ID retrieval and permission checks into helper functions. - Updated UserHandler to utilize BaseHandler for common database and configuration access. - Enhanced logging for user-related operations, including login, registration, and password changes. - Removed redundant context handling in middleware and improved readability. - Introduced FileUtil for file URL generation and management, encapsulating file-related logic. - Refactored FileRepo and UserRepo to streamline database operations and error handling. - Deleted unused request_id middleware and integrated its functionality into request_logger. - Removed legacy test runner script to simplify testing process.
This commit is contained in:
6
Makefile
6
Makefile
@@ -86,7 +86,11 @@ version:
|
||||
.PHONY: build-dev
|
||||
build-dev:
|
||||
@echo "Building $(APP_NAME) for development..."
|
||||
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME)$(if $(filter windows,$(shell go env GOOS)),.exe,)
|
||||
ifeq ($(OS),Windows_NT)
|
||||
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME).exe
|
||||
else
|
||||
go build $(BUILDFLAGS) -ldflags="$(LDFLAGS)" -o $(APP_NAME)
|
||||
endif
|
||||
|
||||
# Docker 相关目标
|
||||
.PHONY: docker-build
|
||||
|
@@ -42,44 +42,65 @@ type Handlers struct {
|
||||
Admin *handler.AdminHandler
|
||||
}
|
||||
|
||||
// InitStep 初始化步骤
|
||||
type InitStep struct {
|
||||
Name string
|
||||
Required bool
|
||||
Init func() error
|
||||
}
|
||||
|
||||
// New 创建新的应用实例
|
||||
func New(config *configs.Config) (*App, error) {
|
||||
app := &App{
|
||||
config: config,
|
||||
}
|
||||
app := &App{config: config}
|
||||
|
||||
// 按顺序初始化各个组件
|
||||
if err := app.initDatabase(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
// 定义初始化步骤
|
||||
steps := app.getInitSteps()
|
||||
|
||||
if err := app.initRedis(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize redis: %w", err)
|
||||
}
|
||||
// 按顺序执行初始化
|
||||
for _, step := range steps {
|
||||
logger.Info("Initializing component", "component", step.Name)
|
||||
|
||||
if err := app.initStorage(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize storage: %w", err)
|
||||
}
|
||||
if err := step.Init(); err != nil {
|
||||
logger.Error("Failed to initialize component",
|
||||
"component", step.Name,
|
||||
"required", step.Required,
|
||||
"error", err)
|
||||
|
||||
if err := app.initHandlers(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize handlers: %w", err)
|
||||
}
|
||||
if step.Required {
|
||||
return nil, fmt.Errorf("failed to initialize required component %s: %w", step.Name, err)
|
||||
}
|
||||
|
||||
if err := app.initRouter(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize router: %w", err)
|
||||
}
|
||||
logger.Warn("Optional component initialization failed, continuing",
|
||||
"component", step.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := app.initServer(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize server: %w", err)
|
||||
}
|
||||
|
||||
if err := app.initPprof(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize pprof: %w", err)
|
||||
logger.Info("Component initialized successfully", "component", step.Name)
|
||||
}
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// getInitSteps 获取初始化步骤列表
|
||||
func (app *App) getInitSteps() []InitStep {
|
||||
steps := []InitStep{
|
||||
// 数据库和Redis根据配置启用,失败时不影响应用启动(内部有禁用检查)
|
||||
{Name: "database", Required: false, Init: app.initDatabase},
|
||||
{Name: "redis", Required: false, Init: app.initRedis},
|
||||
|
||||
// 存储服务是可选的,某些功能可能需要它
|
||||
{Name: "storage", Required: false, Init: app.initStorage},
|
||||
|
||||
// 核心组件,必须成功初始化
|
||||
{Name: "handlers", Required: true, Init: app.initHandlers},
|
||||
{Name: "router", Required: true, Init: app.initRouter},
|
||||
{Name: "server", Required: true, Init: app.initServer},
|
||||
{Name: "pprof", Required: false, Init: app.initPprof},
|
||||
}
|
||||
|
||||
return steps
|
||||
}
|
||||
|
||||
// initDatabase 初始化数据库连接
|
||||
func (a *App) initDatabase() error {
|
||||
if !a.config.Database.Enabled {
|
||||
|
@@ -10,18 +10,16 @@ import (
|
||||
|
||||
// AdminHandler 管理员处理器
|
||||
type AdminHandler struct {
|
||||
db *gorm.DB
|
||||
config *configs.Config
|
||||
*BaseHandler
|
||||
}
|
||||
|
||||
// NewAdminHandler 创建管理员处理器
|
||||
func NewAdminHandler(db *gorm.DB, config *configs.Config) *AdminHandler {
|
||||
handler := &AdminHandler{
|
||||
db: db,
|
||||
config: config,
|
||||
BaseHandler: NewBaseHandler(db, config),
|
||||
}
|
||||
|
||||
logger.Info("AdminHandler initialized")
|
||||
handler.LogInit("AdminHandler")
|
||||
return handler
|
||||
}
|
||||
|
||||
@@ -35,9 +33,9 @@ func (h *AdminHandler) GetSystemSettings(ctx *gin.Context) {
|
||||
|
||||
// 返回系统设置
|
||||
settings := map[string]any{
|
||||
"app_name": h.config.App.Name,
|
||||
"app_name": h.Config.App.Name,
|
||||
"app_version": "1.0.0",
|
||||
"app_mode": h.config.App.Mode,
|
||||
"app_mode": h.Config.App.Mode,
|
||||
}
|
||||
|
||||
response.Success(ctx, settings)
|
||||
|
30
internal/handler/base_handler.go
Normal file
30
internal/handler/base_handler.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/limitcool/starter/configs"
|
||||
"github.com/limitcool/starter/internal/pkg/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseHandler 基础处理器,包含所有Handler的公共字段和方法
|
||||
type BaseHandler struct {
|
||||
DB *gorm.DB
|
||||
Config *configs.Config
|
||||
Helper *HandlerHelper
|
||||
FileUtil *FileUtil
|
||||
}
|
||||
|
||||
// NewBaseHandler 创建基础处理器
|
||||
func NewBaseHandler(db *gorm.DB, config *configs.Config) *BaseHandler {
|
||||
return &BaseHandler{
|
||||
DB: db,
|
||||
Config: config,
|
||||
Helper: NewHandlerHelper(),
|
||||
FileUtil: NewFileUtil("/uploads"), // 默认基础URL
|
||||
}
|
||||
}
|
||||
|
||||
// LogInit 记录Handler初始化日志
|
||||
func (h *BaseHandler) LogInit(handlerName string) {
|
||||
logger.Info(handlerName + " initialized")
|
||||
}
|
@@ -19,20 +19,18 @@ import (
|
||||
|
||||
// FileHandler 文件处理器
|
||||
type FileHandler struct {
|
||||
db *gorm.DB
|
||||
config *configs.Config
|
||||
*BaseHandler
|
||||
storage *filestore.Storage
|
||||
}
|
||||
|
||||
// NewFileHandler 创建文件处理器
|
||||
func NewFileHandler(db *gorm.DB, config *configs.Config, storage *filestore.Storage) *FileHandler {
|
||||
handler := &FileHandler{
|
||||
db: db,
|
||||
config: config,
|
||||
storage: storage,
|
||||
BaseHandler: NewBaseHandler(db, config),
|
||||
storage: storage,
|
||||
}
|
||||
|
||||
logger.Info("FileHandler initialized")
|
||||
handler.LogInit("FileHandler")
|
||||
return handler
|
||||
}
|
||||
|
||||
@@ -72,7 +70,7 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
|
||||
|
||||
// 验证文件类型
|
||||
ext := strings.ToLower(filepath.Ext(req.Filename))
|
||||
if !isAllowedFileType(ext, req.FileType) {
|
||||
if !h.FileUtil.IsAllowedFileType(ext, req.FileType) {
|
||||
logger.WarnContext(reqCtx, "GetUploadURL 不支持的文件类型",
|
||||
"user_id", id,
|
||||
"file_type", req.FileType,
|
||||
@@ -82,8 +80,8 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 生成唯一的文件名和存储路径
|
||||
fileName := generateFileName(req.Filename)
|
||||
storagePath := getStoragePath(req.FileType, fileName, req.IsPublic)
|
||||
fileName := h.FileUtil.GenerateFileName(req.Filename)
|
||||
storagePath := h.FileUtil.GetStoragePath(req.FileType, fileName, req.IsPublic)
|
||||
|
||||
// 生成上传预签名URL
|
||||
uploadURL, err := h.storage.GetUploadPresignedURL(reqCtx, storagePath, req.ContentType, 15) // 15分钟有效期
|
||||
@@ -97,7 +95,7 @@ func (h *FileHandler) GetUploadURL(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建文件记录(状态为pending,等待上传完成确认)
|
||||
fileRepo := model.NewFileRepo(h.db)
|
||||
fileRepo := model.NewFileRepo(h.DB)
|
||||
fileModel := &model.File{
|
||||
Name: fileName,
|
||||
OriginalName: req.Filename,
|
||||
@@ -172,7 +170,7 @@ func (h *FileHandler) ConfirmUpload(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建文件仓库
|
||||
fileRepo := model.NewFileRepo(h.db)
|
||||
fileRepo := model.NewFileRepo(h.DB)
|
||||
|
||||
// 获取文件记录
|
||||
fileModel, err := fileRepo.GetByID(reqCtx, req.FileID)
|
||||
@@ -312,7 +310,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
|
||||
size := fileHeader.Size
|
||||
|
||||
// 验证文件类型
|
||||
if !isAllowedFileType(ext, fileType) {
|
||||
if !h.FileUtil.IsAllowedFileType(ext, fileType) {
|
||||
logger.WarnContext(reqCtx, "UploadFile 文件类型不允许",
|
||||
"user_id", id,
|
||||
"file_type", fileType,
|
||||
@@ -322,7 +320,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证文件大小
|
||||
if !isAllowedFileSize(size, fileType) {
|
||||
if !h.FileUtil.IsAllowedFileSize(size, fileType) {
|
||||
logger.WarnContext(reqCtx, "UploadFile 文件大小超出限制",
|
||||
"user_id", id,
|
||||
"file_type", fileType,
|
||||
@@ -343,8 +341,8 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
|
||||
defer file.Close()
|
||||
|
||||
// 生成文件名和存储路径
|
||||
fileName := generateFileName(originalName)
|
||||
storagePath := getStoragePath(fileType, fileName, isPublic)
|
||||
fileName := h.FileUtil.GenerateFileName(originalName)
|
||||
storagePath := h.FileUtil.GetStoragePath(fileType, fileName, isPublic)
|
||||
|
||||
// 上传文件到存储(权限由路径和Bucket Policy控制)
|
||||
err = h.storage.Put(reqCtx, storagePath, file)
|
||||
@@ -369,7 +367,7 @@ func (h *FileHandler) UploadFile(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建文件仓库
|
||||
fileRepo := model.NewFileRepo(h.db)
|
||||
fileRepo := model.NewFileRepo(h.DB)
|
||||
|
||||
// 记录到数据库
|
||||
fileModel := &model.File{
|
||||
@@ -434,7 +432,7 @@ func (h *FileHandler) GetFileURL(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建文件仓库
|
||||
fileRepo := model.NewFileRepo(h.db)
|
||||
fileRepo := model.NewFileRepo(h.DB)
|
||||
|
||||
// 获取文件信息
|
||||
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
||||
@@ -569,7 +567,7 @@ func (h *FileHandler) ServePublicFile(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建文件仓库
|
||||
fileRepo := model.NewFileRepo(h.db)
|
||||
fileRepo := model.NewFileRepo(h.DB)
|
||||
|
||||
// 获取文件信息
|
||||
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
||||
@@ -632,7 +630,7 @@ func (h *FileHandler) ListFiles(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建文件仓库
|
||||
fileRepo := model.NewFileRepo(h.db)
|
||||
fileRepo := model.NewFileRepo(h.DB)
|
||||
|
||||
// 获取文件列表
|
||||
files, total, err := fileRepo.ListFiles(reqCtx, page, pageSize, fileType, usage, nil)
|
||||
@@ -678,7 +676,7 @@ func (h *FileHandler) GetFileInfo(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建文件仓库
|
||||
fileRepo := model.NewFileRepo(h.db)
|
||||
fileRepo := model.NewFileRepo(h.DB)
|
||||
|
||||
// 获取文件信息
|
||||
fileModel, err := fileRepo.GetByID(reqCtx, fileID)
|
||||
@@ -696,92 +694,3 @@ func (h *FileHandler) GetFileInfo(ctx *gin.Context) {
|
||||
|
||||
response.Success(ctx, fileModel)
|
||||
}
|
||||
|
||||
// 生成文件名
|
||||
func generateFileName(originalName string) string {
|
||||
ext := filepath.Ext(originalName)
|
||||
name := fmt.Sprintf("%d_%s%s", time.Now().UnixNano(), randString(8), ext)
|
||||
return name
|
||||
}
|
||||
|
||||
// 随机字符串
|
||||
func randString(n int) string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letters[int(time.Now().UnixNano()%int64(len(letters)))]
|
||||
time.Sleep(1 * time.Nanosecond) // 确保随机性
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// 获取存储路径
|
||||
func getStoragePath(fileType, fileName string, isPublic bool) string {
|
||||
// 根据是否公开选择根目录
|
||||
var rootDir string
|
||||
if isPublic {
|
||||
rootDir = "public"
|
||||
} else {
|
||||
rootDir = "private"
|
||||
}
|
||||
|
||||
// 根据文件类型选择子目录
|
||||
var typeDir string
|
||||
switch fileType {
|
||||
case model.FileTypeImage:
|
||||
typeDir = "images"
|
||||
case model.FileTypeDocument:
|
||||
typeDir = "documents"
|
||||
case model.FileTypeVideo:
|
||||
typeDir = "videos"
|
||||
case model.FileTypeAudio:
|
||||
typeDir = "audios"
|
||||
default:
|
||||
typeDir = "others"
|
||||
}
|
||||
|
||||
// 添加日期子目录
|
||||
dateDir := time.Now().Format("2006/01/02")
|
||||
|
||||
// 构建完整路径:public/documents/2025/06/17/filename.txt
|
||||
path := filepath.Join(rootDir, typeDir, dateDir, fileName)
|
||||
// 确保Windows上也使用正斜杠
|
||||
return strings.ReplaceAll(path, "\\", "/")
|
||||
}
|
||||
|
||||
// 检查文件类型是否允许
|
||||
func isAllowedFileType(ext string, fileType string) bool {
|
||||
ext = strings.ToLower(ext)
|
||||
switch fileType {
|
||||
case model.FileTypeImage:
|
||||
return ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".gif" || ext == ".webp" || ext == ".svg"
|
||||
case model.FileTypeDocument:
|
||||
return ext == ".pdf" || ext == ".doc" || ext == ".docx" || ext == ".xls" || ext == ".xlsx" || ext == ".txt"
|
||||
case model.FileTypeVideo:
|
||||
return ext == ".mp4" || ext == ".avi" || ext == ".mov" || ext == ".wmv" || ext == ".flv"
|
||||
case model.FileTypeAudio:
|
||||
return ext == ".mp3" || ext == ".wav" || ext == ".ogg" || ext == ".flac"
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查文件大小是否允许
|
||||
func isAllowedFileSize(size int64, fileType string) bool {
|
||||
const (
|
||||
MB = 1024 * 1024
|
||||
)
|
||||
|
||||
switch fileType {
|
||||
case model.FileTypeImage:
|
||||
return size <= 10*MB // 图片最大10MB
|
||||
case model.FileTypeDocument:
|
||||
return size <= 50*MB // 文档最大50MB
|
||||
case model.FileTypeVideo:
|
||||
return size <= 500*MB // 视频最大500MB
|
||||
case model.FileTypeAudio:
|
||||
return size <= 100*MB // 音频最大100MB
|
||||
default:
|
||||
return size <= 50*MB // 其他类型最大50MB
|
||||
}
|
||||
}
|
||||
|
191
internal/handler/file_util.go
Normal file
191
internal/handler/file_util.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/limitcool/starter/internal/model"
|
||||
)
|
||||
|
||||
// FileUtil 文件工具类
|
||||
type FileUtil struct {
|
||||
baseURL string // 基础URL,如 "/uploads" 或 "https://cdn.example.com"
|
||||
}
|
||||
|
||||
// NewFileUtil 创建文件工具实例
|
||||
func NewFileUtil(baseURL string) *FileUtil {
|
||||
return &FileUtil{
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildFileURL 构建文件访问URL
|
||||
func (f *FileUtil) BuildFileURL(path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 确保路径以正斜杠开头
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
|
||||
// 如果baseURL已经是完整URL,直接拼接
|
||||
if strings.HasPrefix(f.baseURL, "http") {
|
||||
return f.baseURL + path
|
||||
}
|
||||
|
||||
// 否则作为相对路径处理
|
||||
return f.baseURL + path
|
||||
}
|
||||
|
||||
// GenerateFileName 生成唯一的文件名
|
||||
func (f *FileUtil) GenerateFileName(originalName string) string {
|
||||
ext := filepath.Ext(originalName)
|
||||
name := fmt.Sprintf("%d_%s%s", time.Now().UnixNano(), randString(8), ext)
|
||||
return name
|
||||
}
|
||||
|
||||
// GetStoragePath 获取存储路径
|
||||
func (f *FileUtil) GetStoragePath(fileType, fileName string, isPublic bool) string {
|
||||
// 根据是否公开选择根目录
|
||||
var rootDir string
|
||||
if isPublic {
|
||||
rootDir = "public"
|
||||
} else {
|
||||
rootDir = "private"
|
||||
}
|
||||
|
||||
// 根据文件类型选择子目录
|
||||
var typeDir string
|
||||
switch fileType {
|
||||
case model.FileTypeImage:
|
||||
typeDir = "images"
|
||||
case model.FileTypeDocument:
|
||||
typeDir = "documents"
|
||||
case model.FileTypeVideo:
|
||||
typeDir = "videos"
|
||||
case model.FileTypeAudio:
|
||||
typeDir = "audios"
|
||||
default:
|
||||
typeDir = "others"
|
||||
}
|
||||
|
||||
// 添加日期子目录
|
||||
dateDir := time.Now().Format("2006/01/02")
|
||||
|
||||
// 构建完整路径:public/documents/2025/06/17/filename.txt
|
||||
path := filepath.Join(rootDir, typeDir, dateDir, fileName)
|
||||
// 确保Windows上也使用正斜杠
|
||||
return strings.ReplaceAll(path, "\\", "/")
|
||||
}
|
||||
|
||||
// SetFileURL 为文件模型设置URL字段
|
||||
func (f *FileUtil) SetFileURL(file *model.File) {
|
||||
if file != nil && file.Path != "" {
|
||||
file.URL = f.BuildFileURL(file.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// SetFileURLs 为文件列表设置URL字段
|
||||
func (f *FileUtil) SetFileURLs(files []model.File) {
|
||||
for i := range files {
|
||||
f.SetFileURL(&files[i])
|
||||
}
|
||||
}
|
||||
|
||||
// IsAllowedFileType 检查文件类型是否允许
|
||||
func (f *FileUtil) IsAllowedFileType(ext, fileType string) bool {
|
||||
ext = strings.ToLower(ext)
|
||||
|
||||
switch fileType {
|
||||
case model.FileTypeImage:
|
||||
return isImageExt(ext)
|
||||
case model.FileTypeDocument:
|
||||
return isDocumentExt(ext)
|
||||
case model.FileTypeVideo:
|
||||
return isVideoExt(ext)
|
||||
case model.FileTypeAudio:
|
||||
return isAudioExt(ext)
|
||||
default:
|
||||
return true // 其他类型默认允许
|
||||
}
|
||||
}
|
||||
|
||||
// IsAllowedFileSize 检查文件大小是否允许
|
||||
func (f *FileUtil) IsAllowedFileSize(size int64, fileType string) bool {
|
||||
const (
|
||||
MB = 1024 * 1024
|
||||
GB = 1024 * MB
|
||||
)
|
||||
|
||||
switch fileType {
|
||||
case model.FileTypeImage:
|
||||
return size <= 10*MB // 图片最大10MB
|
||||
case model.FileTypeDocument:
|
||||
return size <= 50*MB // 文档最大50MB
|
||||
case model.FileTypeVideo:
|
||||
return size <= 500*MB // 视频最大500MB
|
||||
case model.FileTypeAudio:
|
||||
return size <= 100*MB // 音频最大100MB
|
||||
default:
|
||||
return size <= 100*MB // 其他类型最大100MB
|
||||
}
|
||||
}
|
||||
|
||||
// randString 生成随机字符串
|
||||
func randString(n int) string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letters[int(time.Now().UnixNano()%int64(len(letters)))]
|
||||
time.Sleep(1 * time.Nanosecond) // 确保随机性
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// isImageExt 检查是否为图片扩展名
|
||||
func isImageExt(ext string) bool {
|
||||
imageExts := []string{".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".svg"}
|
||||
for _, allowedExt := range imageExts {
|
||||
if ext == allowedExt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isDocumentExt 检查是否为文档扩展名
|
||||
func isDocumentExt(ext string) bool {
|
||||
docExts := []string{".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx", ".txt", ".rtf"}
|
||||
for _, allowedExt := range docExts {
|
||||
if ext == allowedExt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isVideoExt 检查是否为视频扩展名
|
||||
func isVideoExt(ext string) bool {
|
||||
videoExts := []string{".mp4", ".avi", ".mov", ".wmv", ".flv", ".webm", ".mkv"}
|
||||
for _, allowedExt := range videoExts {
|
||||
if ext == allowedExt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isAudioExt 检查是否为音频扩展名
|
||||
func isAudioExt(ext string) bool {
|
||||
audioExts := []string{".mp3", ".wav", ".flac", ".aac", ".ogg", ".wma"}
|
||||
for _, allowedExt := range audioExts {
|
||||
if ext == allowedExt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
157
internal/handler/helper.go
Normal file
157
internal/handler/helper.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/limitcool/starter/internal/api/response"
|
||||
"github.com/limitcool/starter/internal/pkg/errorx"
|
||||
"github.com/limitcool/starter/internal/pkg/logger"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// HandlerHelper 处理器辅助工具
|
||||
type HandlerHelper struct{}
|
||||
|
||||
// NewHandlerHelper 创建处理器辅助工具
|
||||
func NewHandlerHelper() *HandlerHelper {
|
||||
return &HandlerHelper{}
|
||||
}
|
||||
|
||||
// GetUserID 从上下文中获取用户ID,如果不存在则返回错误响应
|
||||
func (h *HandlerHelper) GetUserID(ctx *gin.Context) (int64, bool) {
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
userID, exists := ctx.Get("user_id")
|
||||
if !exists {
|
||||
logger.WarnContext(reqCtx, "用户ID不存在")
|
||||
response.Error(ctx, errorx.ErrUserNoLogin)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return cast.ToInt64(userID), true
|
||||
}
|
||||
|
||||
// BindJSON 绑定JSON参数,如果失败则返回错误响应
|
||||
func (h *HandlerHelper) BindJSON(ctx *gin.Context, req interface{}, operation string) bool {
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
if err := ctx.ShouldBindJSON(req); err != nil {
|
||||
logger.WarnContext(reqCtx, operation+" request validation failed",
|
||||
"error", err,
|
||||
"client_ip", ctx.ClientIP())
|
||||
response.Error(ctx, errorx.ErrInvalidParams.WithError(err))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// HandleDBError 处理数据库错误,统一日志记录和错误响应
|
||||
func (h *HandlerHelper) HandleDBError(ctx *gin.Context, err error, operation string, fields ...interface{}) {
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
// 构建日志字段
|
||||
logFields := []interface{}{
|
||||
"error", err,
|
||||
"operation", operation,
|
||||
}
|
||||
logFields = append(logFields, fields...)
|
||||
|
||||
logger.ErrorContext(reqCtx, operation+" database operation failed", logFields...)
|
||||
response.Error(ctx, err)
|
||||
}
|
||||
|
||||
// HandleNotFoundError 处理资源不存在错误
|
||||
func (h *HandlerHelper) HandleNotFoundError(ctx *gin.Context, err error, operation string, fields ...interface{}) {
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
// 构建日志字段
|
||||
logFields := []interface{}{
|
||||
"operation", operation,
|
||||
}
|
||||
logFields = append(logFields, fields...)
|
||||
|
||||
logger.WarnContext(reqCtx, operation+" resource not found", logFields...)
|
||||
response.Error(ctx, err)
|
||||
}
|
||||
|
||||
// LogSuccess 记录成功操作日志
|
||||
func (h *HandlerHelper) LogSuccess(ctx *gin.Context, operation string, fields ...interface{}) {
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
// 构建日志字段
|
||||
logFields := []interface{}{
|
||||
"operation", operation,
|
||||
}
|
||||
logFields = append(logFields, fields...)
|
||||
|
||||
logger.InfoContext(reqCtx, operation+" successful", logFields...)
|
||||
}
|
||||
|
||||
// LogWarning 记录警告日志
|
||||
func (h *HandlerHelper) LogWarning(ctx *gin.Context, message string, fields ...interface{}) {
|
||||
reqCtx := ctx.Request.Context()
|
||||
logger.WarnContext(reqCtx, message, fields...)
|
||||
}
|
||||
|
||||
// LogError 记录错误日志
|
||||
func (h *HandlerHelper) LogError(ctx *gin.Context, message string, fields ...interface{}) {
|
||||
reqCtx := ctx.Request.Context()
|
||||
logger.ErrorContext(reqCtx, message, fields...)
|
||||
}
|
||||
|
||||
// CheckPermission 检查用户权限(是否为管理员或资源所有者)
|
||||
func (h *HandlerHelper) CheckPermission(ctx *gin.Context, userID int64, resourceOwnerID int64, operation string) bool {
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
// 如果是资源所有者,直接允许
|
||||
if userID == resourceOwnerID {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是否是管理员
|
||||
isAdmin, exists := ctx.Get("is_admin")
|
||||
if exists && cast.ToBool(isAdmin) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 权限不足
|
||||
logger.WarnContext(reqCtx, operation+" permission denied",
|
||||
"user_id", userID,
|
||||
"resource_owner_id", resourceOwnerID,
|
||||
"is_admin", isAdmin)
|
||||
response.Error(ctx, errorx.ErrForbidden.WithMsg("无权限操作此资源"))
|
||||
return false
|
||||
}
|
||||
|
||||
// GetClientInfo 获取客户端信息
|
||||
func (h *HandlerHelper) GetClientInfo(ctx *gin.Context) (string, string) {
|
||||
return ctx.ClientIP(), ctx.Request.UserAgent()
|
||||
}
|
||||
|
||||
// ValidateID 验证ID参数
|
||||
func (h *HandlerHelper) ValidateID(ctx *gin.Context, idStr string, operation string) (uint, bool) {
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
id := cast.ToUint(idStr)
|
||||
if id == 0 {
|
||||
logger.WarnContext(reqCtx, operation+" invalid ID", "id", idStr)
|
||||
response.Error(ctx, errorx.ErrInvalidParams.WithMsg("ID无效"))
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return id, true
|
||||
}
|
||||
|
||||
// ValidateInt64ID 验证int64类型的ID参数
|
||||
func (h *HandlerHelper) ValidateInt64ID(ctx *gin.Context, idStr string, operation string) (int64, bool) {
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
id := cast.ToInt64(idStr)
|
||||
if id == 0 {
|
||||
logger.WarnContext(reqCtx, operation+" invalid ID", "id", idStr)
|
||||
response.Error(ctx, errorx.ErrInvalidParams.WithMsg("ID无效"))
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return id, true
|
||||
}
|
@@ -12,26 +12,23 @@ import (
|
||||
"github.com/limitcool/starter/internal/pkg/crypto"
|
||||
"github.com/limitcool/starter/internal/pkg/errorx"
|
||||
"github.com/limitcool/starter/internal/pkg/logger"
|
||||
"github.com/spf13/cast"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserHandler 用户处理器
|
||||
type UserHandler struct {
|
||||
db *gorm.DB
|
||||
config *configs.Config
|
||||
*BaseHandler
|
||||
authService *AuthService
|
||||
}
|
||||
|
||||
// NewUserHandler 创建用户处理器
|
||||
func NewUserHandler(db *gorm.DB, config *configs.Config) *UserHandler {
|
||||
handler := &UserHandler{
|
||||
db: db,
|
||||
config: config,
|
||||
BaseHandler: NewBaseHandler(db, config),
|
||||
authService: NewAuthService(config),
|
||||
}
|
||||
|
||||
logger.Info("UserHandler initialized")
|
||||
handler.LogInit("UserHandler")
|
||||
return handler
|
||||
}
|
||||
|
||||
@@ -64,7 +61,7 @@ func (h *UserHandler) UserLogin(ctx *gin.Context) {
|
||||
"ip", clientIP)
|
||||
|
||||
// 创建用户仓库
|
||||
userRepo := model.NewUserRepo(h.db)
|
||||
userRepo := model.NewUserRepo(h.DB)
|
||||
|
||||
// 查询用户
|
||||
user, err := userRepo.GetByUsername(reqCtx, req.Username)
|
||||
@@ -161,7 +158,7 @@ func (h *UserHandler) UserRegister(ctx *gin.Context) {
|
||||
clientIP := ctx.ClientIP()
|
||||
|
||||
// 创建用户仓库
|
||||
userRepo := model.NewUserRepo(h.db)
|
||||
userRepo := model.NewUserRepo(h.DB)
|
||||
|
||||
// 检查用户名是否已存在
|
||||
exists, err := userRepo.IsExist(reqCtx, req.Username)
|
||||
@@ -230,97 +227,64 @@ func (h *UserHandler) UserRegister(ctx *gin.Context) {
|
||||
|
||||
// UserInfo 获取用户信息
|
||||
func (h *UserHandler) UserInfo(ctx *gin.Context) {
|
||||
// 获取请求上下文
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := ctx.Get("user_id")
|
||||
if !exists {
|
||||
logger.WarnContext(reqCtx, "UserInfo user ID not found")
|
||||
response.Error(ctx, errorx.ErrUserNoLogin)
|
||||
// 获取用户ID
|
||||
id, ok := h.Helper.GetUserID(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 转换用户ID
|
||||
id := cast.ToInt64(userID)
|
||||
|
||||
// 创建用户仓库
|
||||
userRepo := model.NewUserRepo(h.db)
|
||||
userRepo := model.NewUserRepo(h.DB)
|
||||
|
||||
// 查询用户信息
|
||||
user, err := userRepo.GetByID(reqCtx, id)
|
||||
user, err := userRepo.GetByID(ctx.Request.Context(), id)
|
||||
if err != nil {
|
||||
if errors.Is(err, errorx.ErrUserNotFound) {
|
||||
logger.WarnContext(reqCtx, "UserInfo user not found",
|
||||
"user_id", id)
|
||||
response.Error(ctx, err)
|
||||
h.Helper.HandleNotFoundError(ctx, err, "UserInfo", "user_id", id)
|
||||
return
|
||||
}
|
||||
logger.ErrorContext(reqCtx, "UserInfo failed to query user",
|
||||
"error", err,
|
||||
"user_id", id)
|
||||
response.Error(ctx, err)
|
||||
h.Helper.HandleDBError(ctx, err, "UserInfo", "user_id", id)
|
||||
return
|
||||
}
|
||||
|
||||
// 隐藏敏感信息
|
||||
user.Password = ""
|
||||
|
||||
logger.InfoContext(reqCtx, "UserInfo get user info successful",
|
||||
"user_id", id)
|
||||
|
||||
h.Helper.LogSuccess(ctx, "UserInfo", "user_id", id)
|
||||
response.Success(ctx, user)
|
||||
}
|
||||
|
||||
// UserChangePassword 修改密码
|
||||
func (h *UserHandler) UserChangePassword(ctx *gin.Context) {
|
||||
// 获取请求上下文
|
||||
reqCtx := ctx.Request.Context()
|
||||
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := ctx.Get("user_id")
|
||||
if !exists {
|
||||
logger.WarnContext(reqCtx, "UserChangePassword user ID not found")
|
||||
response.Error(ctx, errorx.ErrUserNoLogin)
|
||||
// 获取用户ID
|
||||
id, ok := h.Helper.GetUserID(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 转换用户ID
|
||||
id := cast.ToInt64(userID)
|
||||
|
||||
// 绑定请求参数
|
||||
var req v1.UserChangePasswordRequest
|
||||
if err := ctx.ShouldBindJSON(&req); err != nil {
|
||||
logger.WarnContext(reqCtx, "UserChangePassword request validation failed",
|
||||
"error", err,
|
||||
"user_id", id)
|
||||
response.Error(ctx, errorx.ErrInvalidParams.WithError(err))
|
||||
if !h.Helper.BindJSON(ctx, &req, "UserChangePassword") {
|
||||
return
|
||||
}
|
||||
|
||||
// 创建用户仓库
|
||||
userRepo := model.NewUserRepo(h.db)
|
||||
userRepo := model.NewUserRepo(h.DB)
|
||||
|
||||
// 查询用户
|
||||
user, err := userRepo.GetByID(reqCtx, id)
|
||||
user, err := userRepo.GetByID(ctx.Request.Context(), id)
|
||||
if err != nil {
|
||||
if errors.Is(err, errorx.ErrUserNotFound) {
|
||||
logger.WarnContext(reqCtx, "UserChangePassword user not found",
|
||||
"user_id", id)
|
||||
response.Error(ctx, err)
|
||||
h.Helper.HandleNotFoundError(ctx, err, "UserChangePassword", "user_id", id)
|
||||
return
|
||||
}
|
||||
logger.ErrorContext(reqCtx, "UserChangePassword failed to query user",
|
||||
"error", err,
|
||||
"user_id", id)
|
||||
response.Error(ctx, err)
|
||||
h.Helper.HandleDBError(ctx, err, "UserChangePassword", "user_id", id)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !crypto.CheckPassword(user.Password, req.OldPassword) {
|
||||
logger.WarnContext(reqCtx, "UserChangePassword old password incorrect",
|
||||
"user_id", id)
|
||||
h.Helper.LogWarning(ctx, "UserChangePassword old password incorrect", "user_id", id)
|
||||
response.Error(ctx, errorx.Errorf(errorx.ErrUserPasswordError, "旧密码错误"))
|
||||
return
|
||||
}
|
||||
@@ -328,24 +292,17 @@ func (h *UserHandler) UserChangePassword(ctx *gin.Context) {
|
||||
// 哈希新密码
|
||||
hashedPassword, err := crypto.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
logger.ErrorContext(reqCtx, "UserChangePassword failed to hash password",
|
||||
"error", err,
|
||||
"user_id", id)
|
||||
h.Helper.LogError(ctx, "UserChangePassword failed to hash password", "error", err, "user_id", id)
|
||||
response.Error(ctx, errorx.WrapError(err, "密码加密失败"))
|
||||
return
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
if err := userRepo.UpdatePassword(reqCtx, id, hashedPassword); err != nil {
|
||||
logger.ErrorContext(reqCtx, "UserChangePassword failed to update password",
|
||||
"error", err,
|
||||
"user_id", id)
|
||||
response.Error(ctx, err)
|
||||
if err := userRepo.UpdatePassword(ctx.Request.Context(), id, hashedPassword); err != nil {
|
||||
h.Helper.HandleDBError(ctx, err, "UserChangePassword", "user_id", id)
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoContext(reqCtx, "UserChangePassword password change successful",
|
||||
"user_id", id)
|
||||
|
||||
h.Helper.LogSuccess(ctx, "UserChangePassword", "user_id", id)
|
||||
response.SuccessNoData(ctx, "密码修改成功")
|
||||
}
|
||||
|
@@ -11,29 +11,9 @@ import (
|
||||
// AdminCheck 管理员检查中间件 - 基于JWT中的is_admin字段
|
||||
func AdminCheck() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取请求上下文
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 从上下文中获取用户ID
|
||||
_, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
logger.WarnContext(ctx, "AdminCheck 未找到用户ID")
|
||||
response.Error(c, errorx.ErrUserNoLogin)
|
||||
c.Abort()
|
||||
if !CheckAdminPermission(c) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户是否为管理员
|
||||
isAdmin, ok := c.Get("is_admin")
|
||||
if !ok || !isAdmin.(bool) {
|
||||
logger.WarnContext(ctx, "AdminCheck 用户不是管理员",
|
||||
"is_admin", isAdmin)
|
||||
response.Error(c, errorx.ErrUserNoLogin.WithMsg("用户无权限"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 继续处理请求
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
@@ -1,37 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestContext 中间件,同时处理请求ID和链路追踪ID
|
||||
func RequestContext() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 处理请求ID
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = fmt.Sprintf("req-%d", time.Now().UnixNano())
|
||||
}
|
||||
c.Set("request_id", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
|
||||
// 处理链路追踪ID
|
||||
traceID := c.GetHeader("X-Trace-ID")
|
||||
if traceID == "" {
|
||||
traceID = fmt.Sprintf("trace-%d", time.Now().UnixNano())
|
||||
}
|
||||
c.Set("trace_id", traceID)
|
||||
c.Header("X-Trace-ID", traceID)
|
||||
|
||||
// 将请求ID和链路追踪ID添加到context.Context中
|
||||
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
|
||||
ctx = context.WithValue(ctx, "trace_id", traceID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
@@ -2,36 +2,45 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/limitcool/starter/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// RequestLoggerMiddleware 是一个记录请求日志的中间件
|
||||
// RequestLoggerMiddleware 是一个记录请求日志的中间件,同时处理请求ID和链路追踪ID
|
||||
func RequestLoggerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 记录开始时间
|
||||
start := time.Now()
|
||||
|
||||
// 获取请求ID
|
||||
// 处理请求ID
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
c.Request.Header.Set("X-Request-ID", requestID)
|
||||
c.Set("request_id", requestID) // 同时存入上下文
|
||||
|
||||
// 将请求ID添加到context.Context中
|
||||
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
requestID = fmt.Sprintf("req-%d", time.Now().UnixNano())
|
||||
}
|
||||
c.Set("request_id", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
|
||||
// 处理链路追踪ID
|
||||
traceID := c.GetHeader("X-Trace-ID")
|
||||
if traceID == "" {
|
||||
traceID = fmt.Sprintf("trace-%d", time.Now().UnixNano())
|
||||
}
|
||||
c.Set("trace_id", traceID)
|
||||
c.Header("X-Trace-ID", traceID)
|
||||
|
||||
// 将请求ID和链路追踪ID添加到context.Context中
|
||||
ctx := context.WithValue(c.Request.Context(), "request_id", requestID)
|
||||
ctx = context.WithValue(ctx, "trace_id", traceID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 获取请求上下文
|
||||
ctx := c.Request.Context()
|
||||
reqCtx := c.Request.Context()
|
||||
|
||||
// 计算延迟
|
||||
latency := time.Since(start)
|
||||
@@ -59,11 +68,11 @@ func RequestLoggerMiddleware() gin.HandlerFunc {
|
||||
|
||||
// 根据状态码选择日志级别
|
||||
if status >= 500 {
|
||||
logger.ErrorContext(ctx, "Server error", fields...)
|
||||
logger.ErrorContext(reqCtx, "Server error", fields...)
|
||||
} else if status >= 400 {
|
||||
logger.WarnContext(ctx, "Client error", fields...)
|
||||
logger.WarnContext(reqCtx, "Client error", fields...)
|
||||
} else {
|
||||
logger.InfoContext(ctx, "Request completed", fields...)
|
||||
logger.InfoContext(reqCtx, "Request completed", fields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -11,19 +11,9 @@ import (
|
||||
// UserCheck 用户检查中间件 - 确保用户已登录
|
||||
func UserCheck() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取请求上下文
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 从上下文获取用户ID
|
||||
_, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
logger.WarnContext(ctx, "用户ID不存在")
|
||||
response.Error(c, errorx.ErrUserNoLogin)
|
||||
c.Abort()
|
||||
if !CheckUserLogin(c) {
|
||||
return
|
||||
}
|
||||
|
||||
// 继续处理请求
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -62,28 +52,21 @@ func UserCheckWithDB(userRepo *model.UserRepo) gin.HandlerFunc {
|
||||
// 适用于只允许普通用户访问的接口
|
||||
func RegularUserCheck() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取请求上下文
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 从上下文获取用户ID
|
||||
_, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
logger.WarnContext(ctx, "用户ID不存在")
|
||||
response.Error(c, errorx.ErrUserNoLogin)
|
||||
c.Abort()
|
||||
// 先检查是否已登录
|
||||
if !CheckUserLogin(c) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户是否为管理员
|
||||
isAdmin, ok := c.Get("is_admin")
|
||||
if ok && isAdmin.(bool) {
|
||||
ctx := c.Request.Context()
|
||||
logger.WarnContext(ctx, "管理员不能访问普通用户接口")
|
||||
response.Error(c, errorx.ErrAccessDenied)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 继续处理请求
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
@@ -4,6 +4,9 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/limitcool/starter/internal/api/response"
|
||||
"github.com/limitcool/starter/internal/pkg/errorx"
|
||||
"github.com/limitcool/starter/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// GetUserID 从上下文中获取用户ID
|
||||
@@ -64,3 +67,39 @@ func GetUserIDString(c *gin.Context) string {
|
||||
}
|
||||
return fmt.Sprintf("%d", id)
|
||||
}
|
||||
|
||||
// CheckUserLogin 检查用户是否已登录,如果未登录则返回错误响应
|
||||
func CheckUserLogin(c *gin.Context) bool {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
_, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
logger.WarnContext(ctx, "用户ID不存在")
|
||||
response.Error(c, errorx.ErrUserNoLogin)
|
||||
c.Abort()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CheckAdminPermission 检查用户是否为管理员,如果不是则返回错误响应
|
||||
func CheckAdminPermission(c *gin.Context) bool {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 先检查是否已登录
|
||||
if !CheckUserLogin(c) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查用户是否为管理员
|
||||
isAdmin, ok := c.Get("is_admin")
|
||||
if !ok || !isAdmin.(bool) {
|
||||
logger.WarnContext(ctx, "用户不是管理员", "is_admin", isAdmin)
|
||||
response.Error(c, errorx.ErrUserNoLogin.WithMsg("用户无权限"))
|
||||
c.Abort()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
@@ -54,8 +54,13 @@ func (File) TableName() string {
|
||||
|
||||
// FileRepo 文件仓库
|
||||
type FileRepo struct {
|
||||
DB *gorm.DB
|
||||
GenericRepo *GenericRepo[File]
|
||||
*GenericRepo[File]
|
||||
fileUtil FileURLBuilder // 文件URL构建器接口
|
||||
}
|
||||
|
||||
// FileURLBuilder 文件URL构建器接口
|
||||
type FileURLBuilder interface {
|
||||
BuildFileURL(path string) string
|
||||
}
|
||||
|
||||
// NewFileRepo 创建文件仓库
|
||||
@@ -64,52 +69,51 @@ func NewFileRepo(db *gorm.DB) *FileRepo {
|
||||
genericRepo.ErrorCode = errorx.ErrorFileNotFoundCode
|
||||
|
||||
return &FileRepo{
|
||||
DB: db,
|
||||
GenericRepo: genericRepo,
|
||||
fileUtil: &defaultFileURLBuilder{}, // 使用默认实现
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建文件记录
|
||||
func (r *FileRepo) Create(ctx context.Context, file *File) error {
|
||||
return r.GenericRepo.Create(ctx, file)
|
||||
// NewFileRepoWithURLBuilder 创建带有自定义URL构建器的文件仓库
|
||||
func NewFileRepoWithURLBuilder(db *gorm.DB, urlBuilder FileURLBuilder) *FileRepo {
|
||||
genericRepo := NewGenericRepo[File](db)
|
||||
genericRepo.ErrorCode = errorx.ErrorFileNotFoundCode
|
||||
|
||||
return &FileRepo{
|
||||
GenericRepo: genericRepo,
|
||||
fileUtil: urlBuilder,
|
||||
}
|
||||
}
|
||||
|
||||
// defaultFileURLBuilder 默认的文件URL构建器
|
||||
type defaultFileURLBuilder struct{}
|
||||
|
||||
func (d *defaultFileURLBuilder) BuildFileURL(path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return "/uploads/" + path
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取文件
|
||||
func (r *FileRepo) GetByID(ctx context.Context, id uint) (*File, error) {
|
||||
file, err := r.GenericRepo.Get(ctx, id, nil)
|
||||
file, err := r.Get(ctx, id, nil)
|
||||
if err != nil {
|
||||
return nil, errorx.WrapError(err, "查询文件失败")
|
||||
}
|
||||
|
||||
// 设置URL字段
|
||||
if file.Path != "" {
|
||||
file.URL = r.buildFileURL(file.Path)
|
||||
file.URL = r.fileUtil.BuildFileURL(file.Path)
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// buildFileURL 构建文件URL
|
||||
func (r *FileRepo) buildFileURL(path string) string {
|
||||
// 这里可以根据实际情况构建URL,例如添加域名前缀等
|
||||
// 简单示例:
|
||||
return "/uploads/" + path
|
||||
}
|
||||
|
||||
// Update 更新文件记录
|
||||
func (r *FileRepo) Update(ctx context.Context, file *File) error {
|
||||
return r.GenericRepo.Update(ctx, file)
|
||||
}
|
||||
|
||||
// Delete 删除文件记录
|
||||
func (r *FileRepo) Delete(ctx context.Context, id uint) error {
|
||||
return r.GenericRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// UpdateFileUsage 更新文件用途
|
||||
func (r *FileRepo) UpdateFileUsage(ctx context.Context, file *File, usage string) error {
|
||||
file.Usage = usage
|
||||
return r.GenericRepo.Update(ctx, file)
|
||||
return r.Update(ctx, file)
|
||||
}
|
||||
|
||||
// ListByUser 获取用户的文件列表
|
||||
@@ -119,7 +123,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
|
||||
Condition: "uploaded_by = ? AND uploaded_by_type = ?",
|
||||
Args: []any{userID, 2},
|
||||
}
|
||||
files, err := r.GenericRepo.List(ctx, page, pageSize, opts)
|
||||
files, err := r.List(ctx, page, pageSize, opts)
|
||||
if err != nil {
|
||||
return nil, errorx.WrapError(err, "查询用户文件列表失败")
|
||||
}
|
||||
@@ -127,7 +131,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
|
||||
// 为所有文件设置URL
|
||||
for i := range files {
|
||||
if files[i].Path != "" {
|
||||
files[i].URL = r.buildFileURL(files[i].Path)
|
||||
files[i].URL = r.fileUtil.BuildFileURL(files[i].Path)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +140,7 @@ func (r *FileRepo) ListByUser(ctx context.Context, userID int64, page, pageSize
|
||||
|
||||
// CountByUser 获取用户的文件总数
|
||||
func (r *FileRepo) CountByUser(ctx context.Context, userID int64) (int64, error) {
|
||||
count, err := r.GenericRepo.Count(ctx, &QueryOptions{
|
||||
count, err := r.Count(ctx, &QueryOptions{
|
||||
Condition: "uploaded_by = ? AND uploaded_by_type = ?",
|
||||
Args: []any{userID, 2},
|
||||
})
|
||||
@@ -180,7 +184,7 @@ func (r *FileRepo) ListFiles(ctx context.Context, page, pageSize int, fileType,
|
||||
}
|
||||
|
||||
// 获取文件列表
|
||||
files, err := r.GenericRepo.List(ctx, page, pageSize, opts)
|
||||
files, err := r.List(ctx, page, pageSize, opts)
|
||||
if err != nil {
|
||||
return nil, 0, errorx.WrapError(err, "查询文件列表失败")
|
||||
}
|
||||
@@ -188,12 +192,12 @@ func (r *FileRepo) ListFiles(ctx context.Context, page, pageSize int, fileType,
|
||||
// 为所有文件设置URL
|
||||
for i := range files {
|
||||
if files[i].Path != "" {
|
||||
files[i].URL = r.buildFileURL(files[i].Path)
|
||||
files[i].URL = r.fileUtil.BuildFileURL(files[i].Path)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
total, err := r.GenericRepo.Count(ctx, opts)
|
||||
total, err := r.Count(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, 0, errorx.WrapError(err, "查询文件总数失败")
|
||||
}
|
||||
|
@@ -47,8 +47,7 @@ func NewUser() *User {
|
||||
|
||||
// UserRepo 用户仓库
|
||||
type UserRepo struct {
|
||||
DB *gorm.DB
|
||||
GenericRepo *GenericRepo[User]
|
||||
*GenericRepo[User]
|
||||
}
|
||||
|
||||
// NewUserRepo 创建用户仓库
|
||||
@@ -57,14 +56,13 @@ func NewUserRepo(db *gorm.DB) *UserRepo {
|
||||
genericRepo.ErrorCode = errorx.ErrorUserNotFoundCode
|
||||
|
||||
return &UserRepo{
|
||||
DB: db,
|
||||
GenericRepo: genericRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (r *UserRepo) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
user, err := r.GenericRepo.Get(ctx, id, nil)
|
||||
user, err := r.Get(ctx, id, nil)
|
||||
if err != nil {
|
||||
return nil, errorx.WrapError(err, "查询用户失败")
|
||||
}
|
||||
@@ -81,7 +79,7 @@ func (r *UserRepo) GetUserWithAvatar(ctx context.Context, id int64) (*User, erro
|
||||
|
||||
// 如果用户有头像,再预加载头像
|
||||
if user.AvatarFileID > 0 {
|
||||
user, err = r.GenericRepo.Get(ctx, id, &QueryOptions{
|
||||
user, err = r.Get(ctx, id, &QueryOptions{
|
||||
Preloads: []string{"AvatarFile"},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -99,7 +97,7 @@ func (r *UserRepo) GetUserWithAvatar(ctx context.Context, id int64) (*User, erro
|
||||
|
||||
// GetByUsername 根据用户名获取用户
|
||||
func (r *UserRepo) GetByUsername(ctx context.Context, username string) (*User, error) {
|
||||
user, err := r.GenericRepo.Get(ctx, nil, &QueryOptions{
|
||||
user, err := r.Get(ctx, nil, &QueryOptions{
|
||||
Condition: "username = ?",
|
||||
Args: []any{username},
|
||||
})
|
||||
@@ -112,24 +110,9 @@ func (r *UserRepo) GetByUsername(ctx context.Context, username string) (*User, e
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (r *UserRepo) Create(ctx context.Context, user *User) error {
|
||||
return r.GenericRepo.Create(ctx, user)
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (r *UserRepo) Update(ctx context.Context, user *User) error {
|
||||
return r.GenericRepo.Update(ctx, user)
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (r *UserRepo) Delete(ctx context.Context, id int64) error {
|
||||
return r.GenericRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// IsExist 检查用户是否存在
|
||||
func (r *UserRepo) IsExist(ctx context.Context, username string) (bool, error) {
|
||||
count, err := r.GenericRepo.Count(ctx, &QueryOptions{
|
||||
count, err := r.Count(ctx, &QueryOptions{
|
||||
Condition: "username = ?",
|
||||
Args: []any{username},
|
||||
})
|
||||
@@ -157,7 +140,7 @@ func (r *UserRepo) ListUsers(ctx context.Context, page, pageSize int, keyword st
|
||||
}
|
||||
|
||||
// 获取用户列表
|
||||
users, err := r.GenericRepo.List(ctx, page, pageSize, opts)
|
||||
users, err := r.List(ctx, page, pageSize, opts)
|
||||
if err != nil {
|
||||
return nil, 0, errorx.WrapError(err, "查询用户列表失败")
|
||||
}
|
||||
@@ -170,7 +153,7 @@ func (r *UserRepo) ListUsers(ctx context.Context, page, pageSize int, keyword st
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
total, err := r.GenericRepo.Count(ctx, opts)
|
||||
total, err := r.Count(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, 0, errorx.WrapError(err, "查询用户总数失败")
|
||||
}
|
||||
|
@@ -1,47 +0,0 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// 运行所有测试
|
||||
func main() {
|
||||
// 获取项目根目录
|
||||
rootDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
fmt.Printf("获取当前目录失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 切换到项目根目录
|
||||
rootDir = filepath.Dir(rootDir)
|
||||
if err := os.Chdir(rootDir); err != nil {
|
||||
fmt.Printf("切换到项目根目录失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 运行单元测试
|
||||
fmt.Println("运行单元测试...")
|
||||
cmd := exec.Command("go", "test", "./internal/pkg/casbin", "./internal/middleware", "-v")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
fmt.Printf("单元测试失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 运行集成测试
|
||||
fmt.Println("\n运行集成测试...")
|
||||
cmd = exec.Command("go", "test", "./test/integration", "-v")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
fmt.Printf("集成测试失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Println("\n所有测试通过!")
|
||||
}
|
Reference in New Issue
Block a user