mirror of
https://github.com/weloe/token-go.git
synced 2025-09-26 19:41:21 +08:00
670 lines
16 KiB
Go
670 lines
16 KiB
Go
package token_go
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/weloe/token-go/config"
|
|
"github.com/weloe/token-go/ctx"
|
|
httpCtx "github.com/weloe/token-go/ctx/go-http-context"
|
|
"github.com/weloe/token-go/log"
|
|
"github.com/weloe/token-go/model"
|
|
"github.com/weloe/token-go/persist"
|
|
"github.com/weloe/token-go/util"
|
|
log2 "log"
|
|
"net/http"
|
|
)
|
|
|
|
// Enforcer issues, validates and revokes login tokens and keeps the
// per-account sessions that track them, persisting everything through a
// persist.Adapter.
type Enforcer struct {
	// conf is the path of the config file the Enforcer was built from; it is
	// only set by InitWithFile.
	conf string
	// loginType is passed to the logger and watcher on login/logout.
	// InitWithConfig sets it to "user"; change it with SetType.
	loginType string
	// config is the Enforcer's own copy of the token configuration.
	config config.TokenConfig
	// generateFunc holds the token generation strategies (see AddTokenGenerateFun).
	generateFunc model.GenerateTokenFunc
	// adapter is the storage backend for tokens and sessions.
	adapter persist.Adapter
	// watcher, when non-nil, is notified after every login and logout.
	watcher persist.Watcher
	// logger receives login/logout events and clean-timer notifications.
	logger log.Logger

	// dispatcher and notifyDispatcher are not referenced in this part of the
	// file. NOTE(review): presumably consulted by notify* helpers defined
	// elsewhere — confirm.
	dispatcher       persist.Dispatcher
	notifyDispatcher bool

	// updatableWatcher is installed by SetUpdatableWatcher, and
	// notifyUpdatableWatcher is toggled by EnableUpdatableWatcher.
	// NOTE(review): presumably consulted by notify* helpers defined
	// elsewhere — confirm.
	updatableWatcher       persist.UpdatableWatcher
	notifyUpdatableWatcher bool

	// authManager is not referenced in this part of the file; its concrete
	// type is opaque here.
	authManager interface{}
}
|
|
|
|
func (e *Enforcer) EnableUpdatableWatcher(b bool) {
|
|
if e.updatableWatcher == nil {
|
|
return
|
|
}
|
|
e.notifyUpdatableWatcher = b
|
|
}
|
|
|
|
func NewDefaultAdapter() persist.Adapter {
|
|
return persist.NewDefaultAdapter()
|
|
}
|
|
|
|
func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context {
|
|
return httpCtx.NewHttpContext(req, writer)
|
|
}
|
|
|
|
func NewEnforcer(adapter persist.Adapter, args ...interface{}) (*Enforcer, error) {
|
|
var err error
|
|
var enforcer *Enforcer
|
|
if len(args) > 2 {
|
|
return nil, fmt.Errorf("NewEnforcer() failed: unexpected args length = %v, it should be less than or equal to 2", len(args))
|
|
}
|
|
if util.HasNil(args) {
|
|
return nil, errors.New("NewEnforcer() failed: parameters cannot be nil")
|
|
}
|
|
|
|
if len(args) == 0 {
|
|
enforcer, err = InitWithDefaultConfig(adapter)
|
|
} else if len(args) == 1 {
|
|
switch args[0].(type) {
|
|
case *config.TokenConfig:
|
|
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
|
|
case string:
|
|
enforcer, err = InitWithFile(args[0].(string), adapter)
|
|
default:
|
|
return nil, errors.New("NewEnforcer() failed: the second parameter should be *TokenConfig or string")
|
|
}
|
|
}
|
|
|
|
return enforcer, err
|
|
}
|
|
|
|
func InitWithDefaultConfig(adapter persist.Adapter) (*Enforcer, error) {
|
|
if adapter == nil {
|
|
return nil, errors.New("InitWithDefaultConfig() failed: parameters cannot be nil")
|
|
}
|
|
return InitWithConfig(config.DefaultTokenConfig(), adapter)
|
|
}
|
|
|
|
func InitWithFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
|
|
if conf == "" || adapter == nil {
|
|
return nil, errors.New("InitWithFile() failed: parameters cannot be nil")
|
|
}
|
|
newConfig, err := config.ReadConfig(conf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
enforcer, err := InitWithConfig(newConfig.TokenConfig, adapter)
|
|
enforcer.conf = conf
|
|
return enforcer, err
|
|
}
|
|
|
|
func InitWithConfig(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
|
|
fm := model.LoadFunctionMap()
|
|
if tokenConfig == nil || adapter == nil {
|
|
return nil, errors.New("InitWithConfig() failed: parameters cannot be nil")
|
|
}
|
|
tokenConfig.InitConfig()
|
|
e := &Enforcer{
|
|
loginType: "user",
|
|
config: *tokenConfig,
|
|
generateFunc: fm,
|
|
adapter: adapter,
|
|
logger: &log.DefaultLogger{},
|
|
}
|
|
|
|
e.startCleanTimer()
|
|
|
|
return e, nil
|
|
}
|
|
|
|
// if e.adapter.(type) == *persist.DefaultAdapter, can start cleanTimer
|
|
func (e *Enforcer) startCleanTimer() {
|
|
defaultAdapter, ok := e.adapter.(*persist.DefaultAdapter)
|
|
if ok {
|
|
if !defaultAdapter.GetCleanTimer() {
|
|
return
|
|
}
|
|
dataRefreshPeriod := e.config.DataRefreshPeriod
|
|
if period := dataRefreshPeriod; period > 0 {
|
|
err := defaultAdapter.StartCleanTimer(period)
|
|
if err != nil {
|
|
log2.Printf("enble adapter cleanTimer failed: %v", err)
|
|
return
|
|
}
|
|
e.logger.StartCleanTimer(period)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (e *Enforcer) SetType(t string) {
|
|
e.loginType = t
|
|
}
|
|
|
|
func (e *Enforcer) GetType() string {
|
|
return e.loginType
|
|
}
|
|
|
|
func (e *Enforcer) GetAdapter() persist.Adapter {
|
|
return e.adapter
|
|
}
|
|
|
|
func (e *Enforcer) SetAdapter(adapter persist.Adapter) {
|
|
e.adapter = adapter
|
|
}
|
|
|
|
func (e *Enforcer) GetWatcher() persist.Watcher {
|
|
return e.watcher
|
|
}
|
|
|
|
func (e *Enforcer) SetWatcher(watcher persist.Watcher) {
|
|
e.watcher = watcher
|
|
}
|
|
|
|
func (e *Enforcer) SetUpdatableWatcher(watcher persist.UpdatableWatcher) {
|
|
if watcher != nil {
|
|
e.updatableWatcher = watcher
|
|
e.notifyUpdatableWatcher = true
|
|
}
|
|
}
|
|
|
|
func (e *Enforcer) GetUpdatableWatcher() persist.UpdatableWatcher {
|
|
return e.updatableWatcher
|
|
}
|
|
|
|
func (e *Enforcer) GetLogger() log.Logger {
|
|
return e.logger
|
|
}
|
|
|
|
func (e *Enforcer) SetLogger(logger log.Logger) {
|
|
e.logger = logger
|
|
}
|
|
|
|
func (e *Enforcer) EnableLog() {
|
|
e.logger.Enable(true)
|
|
}
|
|
|
|
func (e *Enforcer) IsLogEnable() bool {
|
|
return e.logger.IsEnabled()
|
|
}
|
|
|
|
// Login login by id and default loginModel, return tokenValue and error
|
|
func (e *Enforcer) Login(id string, ctx ...ctx.Context) (string, error) {
|
|
return e.LoginByModel(id, model.CreateLoginModelByDevice(""), ctx...)
|
|
}
|
|
|
|
func (e *Enforcer) LoginById(id string, device ...string) (string, error) {
|
|
if len(device) > 0 && device[0] != "" {
|
|
return e.LoginByModel(id, model.CreateLoginModelByDevice(device[0]), nil)
|
|
}
|
|
|
|
return e.Login(id, nil)
|
|
}
|
|
|
|
// LoginByModel logs id in using the options in loginModel and returns the
// new token value.
//
// In order, it:
//  1. creates the token with createLoginToken;
//  2. adds a TokenSign (token value + loginModel.Device) to id's stored
//     session, creating a new session if none exists;
//  3. if config.DoubleToken is set, creates a refresh token and hands it to
//     responseRefreshToken;
//  4. stores the session and the token -> id mapping with loginModel.Timeout;
//  5. hands the token to ResponseToken;
//  6. reports the login to the logger and, if set, the watcher;
//  7. trims the session with deleteRedundantTokenSign when
//     DeviceMaxLoginCount or MaxLoginCount is exceeded (-1 disables a check).
//
// Only c[0] is used; when c is empty a nil ctx.Context is passed on.
// NOTE(review): an error in a later step returns "" without undoing the
// writes of earlier steps (e.g. a stored refresh token) — confirm intended.
func (e *Enforcer) LoginByModel(id string, loginModel *model.Login, c ...ctx.Context) (string, error) {
	if loginModel == nil {
		return "", errors.New("arg loginModel can not be nil")
	}
	// guarantee that c[0] can be indexed below
	if len(c) == 0 {
		c = []ctx.Context{nil}
	}
	var err error
	var session *model.Session
	var tokenValue string
	tokenConfig := e.config

	// allocate token
	tokenValue, err = e.createLoginToken(id, loginModel)
	device := loginModel.Device
	if err != nil {
		return "", err
	}

	// add a tokenSign for this login to the (possibly new) account session
	if session = e.GetSession(id); session == nil {
		session = model.NewSession("0", "account-session", id)
	}
	session.AddDistinctValueTokenSign(&model.TokenSign{
		Value:  tokenValue,
		Device: device,
	})

	if e.config.DoubleToken {
		// err is shadowed in this scope; every error is returned immediately
		refreshToken, err := e.createRefreshToken(id, tokenValue, loginModel)
		if err != nil {
			return "", err
		}
		err = e.responseRefreshToken(refreshToken, loginModel, c[0])
		if err != nil {
			return "", err
		}
	}

	timeout := loginModel.Timeout
	// reset session
	err = e.SetSession(id, session, timeout)
	if err != nil {
		return "", err
	}

	// set token-id
	err = e.SetIdByToken(id, tokenValue, timeout)
	if err != nil {
		return "", err
	}

	// response token
	err = e.ResponseToken(tokenValue, loginModel, c[0])
	if err != nil {
		return "", err
	}

	// event payload passed to both the logger and the watcher
	m := &model.Login{
		Device:          device,
		IsLastingCookie: loginModel.IsLastingCookie,
		Timeout:         timeout,
		JwtData:         loginModel.JwtData,
		Token:           tokenValue,
	}

	// called logger
	e.logger.Login(e.loginType, id, tokenValue, m)

	// called watcher
	if e.watcher != nil {
		e.watcher.Login(e.loginType, id, tokenValue, m)
	}

	// Per-device limit: re-read the stored session and trim it when this
	// device holds more than DeviceMaxLoginCount tokens.
	// NOTE(review): deleteRedundantTokenSign receives the whole session, not
	// only this device's signs — confirm that it trims per device.
	if device != "" && tokenConfig.DeviceMaxLoginCount != -1 {
		if session = e.GetSession(id); session != nil {
			// get by login device
			tokenSignList := session.GetFilterTokenSignSlice(device)
			if len(tokenSignList) > int(tokenConfig.DeviceMaxLoginCount) {
				err = e.deleteRedundantTokenSign(session, tokenConfig.DeviceMaxLoginCount)
				if err != nil {
					return "", err
				}
			}
		}
	}

	// Account limit: trim when the total number of tokens held by this
	// account exceeds MaxLoginCount.
	if tokenConfig.MaxLoginCount != -1 {
		if session = e.GetSession(id); session != nil {
			if session.TokenSignSize() <= int(tokenConfig.MaxLoginCount) {
				return tokenValue, nil
			}
			err = e.deleteRedundantTokenSign(session, tokenConfig.MaxLoginCount)
			if err != nil {
				return "", err
			}
		}
	}

	return tokenValue, nil
}
|
|
|
|
// Logout user logout
|
|
func (e *Enforcer) Logout(ctx ctx.Context) error {
|
|
tokenConfig := e.config
|
|
|
|
token := e.GetRequestToken(ctx)
|
|
if token == "" {
|
|
return errors.New("logout() failed: token doesn't exist")
|
|
}
|
|
if e.config.IsReadCookie {
|
|
ctx.Response().DeleteCookie(tokenConfig.TokenName,
|
|
tokenConfig.CookieConfig.Path,
|
|
tokenConfig.CookieConfig.Domain)
|
|
}
|
|
|
|
err := e.LogoutByToken(token)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LogoutById force user to logout
|
|
func (e *Enforcer) LogoutById(id string, device ...string) error {
|
|
session := e.GetSession(id)
|
|
if session != nil {
|
|
for _, tokenSign := range session.TokenSignList {
|
|
if len(device) > 0 && device[0] != "" && tokenSign.Device == device[0] {
|
|
err := e.LogoutByToken(tokenSign.Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
err := e.LogoutByToken(tokenSign.Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LogoutByToken clear token info
|
|
func (e *Enforcer) LogoutByToken(token string) error {
|
|
var err error
|
|
// delete token-id
|
|
id := e.getIdByToken(token)
|
|
if id == "" {
|
|
return errors.New("user not logged in")
|
|
}
|
|
// delete token-id
|
|
err = e.deleteIdByToken(token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
session := e.GetSession(id)
|
|
if session != nil {
|
|
// delete tokenSign
|
|
session.RemoveTokenSign(token)
|
|
err = e.UpdateSession(id, session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// check TokenSignList length, if length == 0, delete this session
|
|
if session != nil && session.TokenSignSize() == 0 {
|
|
err = e.DeleteSession(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
err = e.deleteRefreshToken(token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
e.logger.Logout(e.loginType, id, token)
|
|
|
|
if e.watcher != nil {
|
|
e.watcher.Logout(e.loginType, id, token)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsLoginById check if user logged in by loginId.
|
|
// check all tokenValue and if one is validated return true
|
|
func (e *Enforcer) IsLoginById(id string, device ...string) (bool, error) {
|
|
var err error
|
|
session := e.GetSession(id)
|
|
if session != nil {
|
|
var l []*model.TokenSign
|
|
if len(device) > 0 && device[0] != "" {
|
|
l = session.GetFilterTokenSignSlice(device[0])
|
|
} else {
|
|
l = session.TokenSignList
|
|
}
|
|
for _, tokenSign := range l {
|
|
err = e.CheckLoginByToken(tokenSign.Value)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
return false, err
|
|
}
|
|
|
|
// GetId get the id from the Adapter, do not check the value
|
|
// if GetId()= -4, it means that user be replaced
|
|
// if GetId()= -5, it means that user be kicked
|
|
// if GetId()= -6, it means that user be banned
|
|
func (e *Enforcer) GetId(ctx ctx.Context) string {
|
|
token := e.GetRequestToken(ctx)
|
|
return e.GetIdByToken(token)
|
|
}
|
|
|
|
// GetIdByToken get the id from the Adapter
|
|
func (e *Enforcer) GetIdByToken(token string) string {
|
|
if token == "" {
|
|
return ""
|
|
}
|
|
loginId := e.getIdByToken(token)
|
|
return loginId
|
|
}
|
|
|
|
// IsLogin check if user logged in by token.
|
|
func (e *Enforcer) IsLogin(ctx ctx.Context) (bool, error) {
|
|
tokenValue := e.GetRequestToken(ctx)
|
|
return e.IsLoginByToken(tokenValue)
|
|
}
|
|
|
|
func (e *Enforcer) IsLoginByToken(tokenValue string) (bool, error) {
|
|
if tokenValue == "" {
|
|
return false, nil
|
|
}
|
|
|
|
err := e.CheckLoginByToken(tokenValue)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (e *Enforcer) CheckLogin(ctx ctx.Context) error {
|
|
_, err := e.GetLoginId(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Enforcer) CheckLoginByToken(token string) error {
|
|
_, err := e.GetLoginIdByToken(token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetLoginId get id and check it
|
|
func (e *Enforcer) GetLoginId(ctx ctx.Context) (string, error) {
|
|
tokenValue := e.GetRequestToken(ctx)
|
|
return e.GetLoginIdByToken(tokenValue)
|
|
}
|
|
|
|
func (e *Enforcer) GetLoginIdByToken(token string) (string, error) {
|
|
str := e.GetIdByToken(token)
|
|
if str == "" {
|
|
return "", errors.New("GetLoginId() failed: not logged in")
|
|
}
|
|
validate, err := e.checkId(str)
|
|
if !validate {
|
|
return "", err
|
|
}
|
|
// auto refresh timeout, When the user accesses
|
|
if e.config.AutoRenew {
|
|
_ = e.updateTokenTimeout(token, e.config.Timeout)
|
|
_ = e.UpdateSessionTimeout(str, e.config.Timeout)
|
|
}
|
|
return str, nil
|
|
}
|
|
|
|
func (e *Enforcer) GetLoginCount(id string, device ...string) int {
|
|
if session := e.GetSession(id); session != nil {
|
|
if len(device) > 0 && device[0] != "" {
|
|
return session.GetFilterTokenSign(device[0]).Len()
|
|
}
|
|
|
|
return session.TokenSignSize()
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// GetBannedTime get banned time
|
|
func (e *Enforcer) GetBannedTime(id string, service string) int64 {
|
|
timeout := e.getBannedTime(id, service)
|
|
return timeout
|
|
}
|
|
|
|
// GetRequestToken read token from requestHeader | cookie | requestBody
|
|
func (e *Enforcer) GetRequestToken(ctx ctx.Context) string {
|
|
var tokenValue string
|
|
if ctx == nil {
|
|
return ""
|
|
}
|
|
if e.config.IsReadHeader {
|
|
if tokenValue = ctx.Request().Header(e.config.TokenName); tokenValue != "" {
|
|
return tokenValue
|
|
}
|
|
}
|
|
if e.config.IsReadCookie {
|
|
if tokenValue = ctx.Request().Cookie(e.config.TokenName); tokenValue != "" {
|
|
return tokenValue
|
|
}
|
|
|
|
}
|
|
if e.config.IsReadBody {
|
|
if tokenValue = ctx.Request().PostForm(e.config.TokenName); tokenValue != "" {
|
|
return tokenValue
|
|
}
|
|
}
|
|
return tokenValue
|
|
}
|
|
|
|
// AddTokenGenerateFun add token generate strategy
|
|
func (e *Enforcer) AddTokenGenerateFun(tokenStyle string, f model.HandlerFunc) error {
|
|
e.generateFunc.AddFunc(tokenStyle, f)
|
|
return nil
|
|
}
|
|
|
|
func (e *Enforcer) GetSession(id string) *model.Session {
|
|
if v := e.adapter.Get(e.spliceSessionKey(id), util.GetType(&model.Session{})); v != nil {
|
|
return v.(*model.Session)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Enforcer) SetSession(id string, session *model.Session, timeout int64) error {
|
|
err := e.notifySet(e.spliceSessionKey(id), session, timeout)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Enforcer) DeleteSession(id string) error {
|
|
err := e.notifyDelete(e.spliceSessionKey(id))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Enforcer) UpdateSession(id string, session *model.Session) error {
|
|
err := e.notifyUpdate(e.spliceSessionKey(id), session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Enforcer) UpdateSessionTimeout(id string, timeout int64) error {
|
|
err := e.notifyUpdateTimeout(id, timeout)
|
|
return err
|
|
}
|
|
|
|
func (e *Enforcer) GetTokenConfig() config.TokenConfig {
|
|
return e.config
|
|
}
|
|
|
|
func (e *Enforcer) GetLoginCounts() (int, error) {
|
|
adapter, ok := e.adapter.(persist.BatchAdapter)
|
|
if !ok {
|
|
return 0, fmt.Errorf("the adapter does not implement persist.BatchAdapter")
|
|
}
|
|
c, err := adapter.GetCountsFilteredKey(e.spliceSessionKey(""))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
func (e *Enforcer) GetLoginTokenCounts() (int, error) {
|
|
adapter, ok := e.adapter.(persist.BatchAdapter)
|
|
if !ok {
|
|
return 0, fmt.Errorf("the adapter does not implement persist.BatchAdapter")
|
|
}
|
|
c, err := adapter.GetCountsFilteredKey(e.spliceTokenKey(""))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
func (e *Enforcer) GetRefreshToken(tokenValue string) string {
|
|
return e.getRefreshTokenValue(tokenValue)
|
|
}
|
|
|
|
func (e *Enforcer) RefreshToken(refreshToken string, refreshModel ...*model.Refresh) (*model.RefreshRes, error) {
|
|
var m *model.Refresh
|
|
if len(refreshModel) != 0 {
|
|
m = refreshModel[0]
|
|
} else {
|
|
m = model.DefaultRefresh()
|
|
}
|
|
return e.RefreshTokenByModel(refreshToken, m, nil)
|
|
}
|
|
|
|
func (e *Enforcer) RefreshTokenByModel(refreshToken string, refreshModel *model.Refresh, ctx ...ctx.Context) (*model.RefreshRes, error) {
|
|
if refreshModel == nil {
|
|
return nil, errors.New("arg refreshModel can not be nil")
|
|
}
|
|
if !e.config.DoubleToken {
|
|
return nil, fmt.Errorf("double tokens are not enabled")
|
|
}
|
|
refreshTokenSign := e.getRefreshTokenSign(refreshToken)
|
|
if refreshTokenSign == nil {
|
|
return nil, fmt.Errorf("the refresh token does not exist: %v", refreshToken)
|
|
}
|
|
err := e.deleteRefreshToken(refreshTokenSign.Token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
login := &model.Login{
|
|
Device: refreshTokenSign.Device,
|
|
IsLastingCookie: refreshModel.IsLastingCookie,
|
|
Timeout: refreshModel.Timeout,
|
|
JwtData: refreshModel.JwtData,
|
|
Token: refreshModel.Token,
|
|
RefreshToken: refreshModel.RefreshToken,
|
|
RefreshTokenTimeout: refreshModel.RefreshTokenTimeout,
|
|
}
|
|
|
|
token, err := e.LoginByModel(refreshTokenSign.Id, login, ctx...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &model.RefreshRes{
|
|
Token: token,
|
|
RefreshToken: refreshToken,
|
|
}, nil
|
|
}
|
|
|
|
func (e *Enforcer) GetLoginDevices(id string) []string {
|
|
session := e.GetSession(id)
|
|
if session == nil {
|
|
return nil
|
|
}
|
|
return session.GetAllDevice()
|
|
}
|
|
|
|
func (e *Enforcer) GetDeviceByToken(token string) string {
|
|
id := e.getIdByToken(token)
|
|
session := e.GetSession(id)
|
|
if session == nil {
|
|
return ""
|
|
}
|
|
tokenSign := session.GetTokenSign(token)
|
|
if tokenSign == nil {
|
|
return ""
|
|
}
|
|
return tokenSign.Device
|
|
}
|