diff --git a/internal/app/middle/token.go b/internal/app/middle/token.go index 906559a..5e8340b 100644 --- a/internal/app/middle/token.go +++ b/internal/app/middle/token.go @@ -1,7 +1,6 @@ package middle import ( - "errors" "slices" "strings" @@ -46,36 +45,19 @@ func CheckToken() gin.HandlerFunc { return strings.HasPrefix(c.Request.URL.Path, s) }) { var token string - if c.Request.Header.Get("token") != "" { - token = c.Request.Header.Get("token") + if c.Request.Header.Get("Authorization") != "" { + token = strings.TrimPrefix(c.Request.Header.Get("Authorization"), "bearer ") } else { token = c.Query("token") } - if _, err := utils.VerifyToken(token); err != nil { + if mc, err := utils.VerifyToken(token); err != nil { rErr(c, -2, "token校验失败", err) return - } - if username, err := getUser(c); err != nil { - rErr(c, -1, "无法获取user信息", err) } else { - c.Set(eum.CtxUserName, username) - c.Set(eum.CtxRole, repository.UserRepository.GetUserByName(username).Role) + c.Set(eum.CtxUserName, mc.Username) + c.Set(eum.CtxRole, repository.UserRepository.GetUserByName(mc.Username).Role) } } c.Next() } } - -func getUser(ctx *gin.Context) (string, error) { - var token string - if ctx.Request.Header.Get("token") != "" { - token = ctx.Request.Header.Get("token") - } else { - token = ctx.Query("token") - } - if mc, err := utils.VerifyToken(token); err == nil && mc != nil { - return mc.Username, nil - } else { - return "", errors.Join(errors.New("用户信息获取失败"), err) - } -}