Files
token-go/enforcer.go
2023-12-04 00:38:27 +08:00

670 lines
16 KiB
Go

package token_go
import (
"errors"
"fmt"
"github.com/weloe/token-go/config"
"github.com/weloe/token-go/ctx"
httpCtx "github.com/weloe/token-go/ctx/go-http-context"
"github.com/weloe/token-go/log"
"github.com/weloe/token-go/model"
"github.com/weloe/token-go/persist"
"github.com/weloe/token-go/util"
log2 "log"
"net/http"
)
// Enforcer is the central object of the package: it bundles the token
// configuration with the storage adapter, logger and watchers that the login,
// logout and session methods operate on.
type Enforcer struct {
// conf is the config file path; InitWithFile stores it after a successful init.
conf string
// loginType is passed to the logger and watcher on login/logout.
// InitWithConfig initialises it to "user".
loginType string
// config is a copy of the TokenConfig handed to InitWithConfig.
config config.TokenConfig
// generateFunc holds the token generation strategies (see AddTokenGenerateFun).
generateFunc model.GenerateTokenFunc
// adapter is the storage backend queried for sessions and token data.
adapter persist.Adapter
// watcher, when non-nil, is called on Login and Logout.
watcher persist.Watcher
// logger receives login/logout and timer events.
logger log.Logger
// NOTE(review): dispatcher/notifyDispatcher are not referenced in the visible
// part of the file; presumably used by the notify* helpers — confirm.
dispatcher persist.Dispatcher
notifyDispatcher bool
// updatableWatcher is set through SetUpdatableWatcher, which also enables
// notifyUpdatableWatcher; EnableUpdatableWatcher toggles the flag afterwards.
updatableWatcher persist.UpdatableWatcher
notifyUpdatableWatcher bool
// NOTE(review): authManager is untyped and not referenced in the visible part
// of the file — confirm its intended use.
authManager interface{}
}
// EnableUpdatableWatcher turns notifications to the updatable watcher on or
// off. The call has no effect while no updatable watcher is set.
func (e *Enforcer) EnableUpdatableWatcher(b bool) {
	if e.updatableWatcher != nil {
		e.notifyUpdatableWatcher = b
	}
}
// NewDefaultAdapter returns the adapter created by persist.NewDefaultAdapter.
func NewDefaultAdapter() persist.Adapter {
return persist.NewDefaultAdapter()
}
// NewHttpContext wraps a net/http request and response writer in a
// ctx.Context by delegating to httpCtx.NewHttpContext.
func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context {
return httpCtx.NewHttpContext(req, writer)
}
// NewEnforcer creates an Enforcer backed by adapter.
//
// The optional args select the configuration source:
//   - no argument: the default token configuration is used;
//   - one *config.TokenConfig: that configuration is used;
//   - one string: it is treated as the path of a config file.
//
// An error is returned when an argument is nil, has an unsupported type, or
// when the number of arguments has no matching initializer. In particular,
// passing two arguments now yields an error instead of a nil Enforcer with a
// nil error.
func NewEnforcer(adapter persist.Adapter, args ...interface{}) (*Enforcer, error) {
	if len(args) > 2 {
		return nil, fmt.Errorf("NewEnforcer() failed: unexpected args length = %v, it should be less than or equal to 2", len(args))
	}
	if util.HasNil(args) {
		return nil, errors.New("NewEnforcer() failed: parameters cannot be nil")
	}

	switch len(args) {
	case 0:
		return InitWithDefaultConfig(adapter)
	case 1:
		switch conf := args[0].(type) {
		case *config.TokenConfig:
			return InitWithConfig(conf, adapter)
		case string:
			return InitWithFile(conf, adapter)
		default:
			return nil, errors.New("NewEnforcer() failed: the second parameter should be *TokenConfig or string")
		}
	default:
		// No initializer accepts this many arguments; fail loudly rather than
		// handing the caller a nil *Enforcer together with a nil error.
		return nil, fmt.Errorf("NewEnforcer() failed: no initializer accepts args length = %v", len(args))
	}
}
// InitWithDefaultConfig builds an Enforcer from config.DefaultTokenConfig().
// It fails when adapter is nil.
func InitWithDefaultConfig(adapter persist.Adapter) (*Enforcer, error) {
	if adapter != nil {
		return InitWithConfig(config.DefaultTokenConfig(), adapter)
	}
	return nil, errors.New("InitWithDefaultConfig() failed: parameters cannot be nil")
}
// InitWithFile builds an Enforcer from the config file at path conf and
// remembers that path on the returned Enforcer.
//
// It returns an error when conf is empty, adapter is nil, the file cannot be
// read, or the configuration it contains is rejected by InitWithConfig.
func InitWithFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
	if conf == "" || adapter == nil {
		return nil, errors.New("InitWithFile() failed: parameters cannot be nil")
	}
	newConfig, err := config.ReadConfig(conf)
	if err != nil {
		return nil, err
	}
	enforcer, err := InitWithConfig(newConfig.TokenConfig, adapter)
	if err != nil {
		// InitWithConfig returns a nil Enforcer on failure; bail out before
		// touching it to avoid a nil pointer dereference.
		return nil, err
	}
	enforcer.conf = conf
	return enforcer, nil
}
// InitWithConfig builds an Enforcer from tokenConfig and adapter.
//
// The configuration is initialised via InitConfig and then copied into the
// Enforcer; the login type starts as "user" and a DefaultLogger is installed.
// The clean timer is started afterwards (see startCleanTimer). It fails when
// either argument is nil.
func InitWithConfig(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
	generators := model.LoadFunctionMap()
	if adapter == nil || tokenConfig == nil {
		return nil, errors.New("InitWithConfig() failed: parameters cannot be nil")
	}

	tokenConfig.InitConfig()

	enforcer := new(Enforcer)
	enforcer.loginType = "user"
	enforcer.config = *tokenConfig
	enforcer.generateFunc = generators
	enforcer.adapter = adapter
	enforcer.logger = &log.DefaultLogger{}

	enforcer.startCleanTimer()
	return enforcer, nil
}
// startCleanTimer starts the adapter's clean timer. It only acts when the
// adapter is a *persist.DefaultAdapter whose GetCleanTimer reports true and
// the configured DataRefreshPeriod is positive. A failure to start the timer
// is printed via the standard logger and otherwise ignored.
func (e *Enforcer) startCleanTimer() {
	memAdapter, isDefault := e.adapter.(*persist.DefaultAdapter)
	if !isDefault || !memAdapter.GetCleanTimer() {
		return
	}

	period := e.config.DataRefreshPeriod
	if period <= 0 {
		return
	}

	if err := memAdapter.StartCleanTimer(period); err != nil {
		log2.Printf("enble adapter cleanTimer failed: %v", err)
		return
	}
	e.logger.StartCleanTimer(period)
}
// SetType sets the login type that is passed to the logger and watcher on
// login and logout.
func (e *Enforcer) SetType(t string) {
e.loginType = t
}
// GetType returns the current login type.
func (e *Enforcer) GetType() string {
return e.loginType
}
// GetAdapter returns the storage adapter in use.
func (e *Enforcer) GetAdapter() persist.Adapter {
return e.adapter
}
// SetAdapter replaces the storage adapter. The value is stored as given; no
// nil check is performed and the clean timer is not restarted.
func (e *Enforcer) SetAdapter(adapter persist.Adapter) {
e.adapter = adapter
}
// GetWatcher returns the watcher, or nil when none has been set.
func (e *Enforcer) GetWatcher() persist.Watcher {
return e.watcher
}
// SetWatcher installs the watcher that is called on login and logout.
// Passing nil disables those calls.
func (e *Enforcer) SetWatcher(watcher persist.Watcher) {
e.watcher = watcher
}
// SetUpdatableWatcher installs watcher as the updatable watcher and enables
// its notifications. A nil watcher is ignored, leaving the current one intact.
func (e *Enforcer) SetUpdatableWatcher(watcher persist.UpdatableWatcher) {
	if watcher == nil {
		return
	}
	e.updatableWatcher = watcher
	e.notifyUpdatableWatcher = true
}
// GetUpdatableWatcher returns the updatable watcher, or nil when none is set.
func (e *Enforcer) GetUpdatableWatcher() persist.UpdatableWatcher {
return e.updatableWatcher
}
// GetLogger returns the logger in use.
func (e *Enforcer) GetLogger() log.Logger {
return e.logger
}
// SetLogger replaces the logger. The value is stored as given; other methods
// call the logger without a nil check.
func (e *Enforcer) SetLogger(logger log.Logger) {
e.logger = logger
}
// EnableLog enables the logger by calling Enable(true) on it.
func (e *Enforcer) EnableLog() {
e.logger.Enable(true)
}
// IsLogEnable reports whether the logger is enabled.
func (e *Enforcer) IsLogEnable() bool {
return e.logger.IsEnabled()
}
// Login logs in the account id with a login model created for the empty
// device and returns the token value. Any ctx values are forwarded unchanged
// to LoginByModel.
func (e *Enforcer) Login(id string, ctx ...ctx.Context) (string, error) {
	defaultModel := model.CreateLoginModelByDevice("")
	return e.LoginByModel(id, defaultModel, ctx...)
}
// LoginById logs in the account id without a request context and returns the
// token value. The first element of device, when present, names the login
// device; otherwise the empty device is used.
func (e *Enforcer) LoginById(id string, device ...string) (string, error) {
	var deviceName string
	if len(device) > 0 {
		deviceName = device[0]
	}
	return e.LoginByModel(id, model.CreateLoginModelByDevice(deviceName), nil)
}
// LoginByModel logs in the account id using the options in loginModel and
// returns the issued token value.
//
// An optional ctx.Context may be passed as c[0]; it is forwarded to
// ResponseToken (and to responseRefreshToken when DoubleToken is enabled).
// When no context is given, nil is forwarded instead.
//
// The steps run in this order: create the token, add its token sign to the
// account session, optionally create and respond with a refresh token, store
// the session and the token-id mapping, respond with the token, notify the
// logger and watcher, and finally call deleteRedundantTokenSign when the
// per-device or per-account login count exceeds its configured maximum.
//
// It returns an error when loginModel is nil or when any of the steps fails.
func (e *Enforcer) LoginByModel(id string, loginModel *model.Login, c ...ctx.Context) (string, error) {
if loginModel == nil {
return "", errors.New("arg loginModel can not be nil")
}
// guarantee that c[0] exists so it can be forwarded below (possibly as nil)
if len(c) == 0 {
c = []ctx.Context{nil}
}
var err error
var session *model.Session
var tokenValue string
tokenConfig := e.config
// allocate token
tokenValue, err = e.createLoginToken(id, loginModel)
device := loginModel.Device
if err != nil {
return "", err
}
// add tokenSign to the existing session, or to a fresh one if the account has none
if session = e.GetSession(id); session == nil {
session = model.NewSession("0", "account-session", id)
}
session.AddDistinctValueTokenSign(&model.TokenSign{
Value: tokenValue,
Device: device,
})
// with double tokens enabled, also create a refresh token and respond with it
if e.config.DoubleToken {
refreshToken, err := e.createRefreshToken(id, tokenValue, loginModel)
if err != nil {
return "", err
}
err = e.responseRefreshToken(refreshToken, loginModel, c[0])
if err != nil {
return "", err
}
}
timeout := loginModel.Timeout
// store the session with the login timeout
err = e.SetSession(id, session, timeout)
if err != nil {
return "", err
}
// set token-id
err = e.SetIdByToken(id, tokenValue, timeout)
if err != nil {
return "", err
}
// response token
err = e.ResponseToken(tokenValue, loginModel, c[0])
if err != nil {
return "", err
}
// build the login info handed to the logger and watcher
m := &model.Login{
Device: device,
IsLastingCookie: loginModel.IsLastingCookie,
Timeout: timeout,
JwtData: loginModel.JwtData,
Token: tokenValue,
}
// call logger, then watcher if one is set
e.logger.Login(e.loginType, id, tokenValue, m)
if e.watcher != nil {
e.watcher.Login(e.loginType, id, tokenValue, m)
}
// per-device limit: skipped for the empty device or when DeviceMaxLoginCount is -1
if device != "" && tokenConfig.DeviceMaxLoginCount != -1 {
if session = e.GetSession(id); session != nil {
// get by login device
tokenSignList := session.GetFilterTokenSignSlice(device)
if len(tokenSignList) > int(tokenConfig.DeviceMaxLoginCount) {
err = e.deleteRedundantTokenSign(session, tokenConfig.DeviceMaxLoginCount)
if err != nil {
return "", err
}
}
}
}
// check if the number of sessions for this account exceeds the maximum limit;
// skipped when MaxLoginCount is -1
if tokenConfig.MaxLoginCount != -1 {
if session = e.GetSession(id); session != nil {
if session.TokenSignSize() <= int(tokenConfig.MaxLoginCount) {
return tokenValue, nil
}
err = e.deleteRedundantTokenSign(session, tokenConfig.MaxLoginCount)
if err != nil {
return "", err
}
}
}
return tokenValue, nil
}
// Logout logs out the user identified by the request token found in ctx.
//
// When cookie reading is enabled the token cookie is deleted from the
// response before the token data is cleared through LogoutByToken. It returns
// an error when the request carries no token or when clearing fails.
func (e *Enforcer) Logout(ctx ctx.Context) error {
	token := e.GetRequestToken(ctx)
	if token == "" {
		return errors.New("logout() failed: token doesn't exist")
	}

	if conf := e.config; conf.IsReadCookie {
		cookie := conf.CookieConfig
		ctx.Response().DeleteCookie(conf.TokenName, cookie.Path, cookie.Domain)
	}

	return e.LogoutByToken(token)
}
// LogoutById forces the account id to log out.
//
// When device is given and its first element is non-empty, only the tokens
// issued for that device are logged out; otherwise every token of the
// account is. An account without a session is a no-op. The first error
// returned by LogoutByToken aborts the operation and is returned.
func (e *Enforcer) LogoutById(id string, device ...string) error {
	session := e.GetSession(id)
	if session == nil {
		return nil
	}

	var deviceFilter string
	if len(device) > 0 {
		deviceFilter = device[0]
	}

	// Collect the token values first: LogoutByToken removes token signs from
	// the session, which may be the very object being iterated here.
	tokens := make([]string, 0, len(session.TokenSignList))
	for _, tokenSign := range session.TokenSignList {
		if deviceFilter != "" && tokenSign.Device != deviceFilter {
			continue
		}
		tokens = append(tokens, tokenSign.Value)
	}

	for _, token := range tokens {
		if err := e.LogoutByToken(token); err != nil {
			return err
		}
	}
	return nil
}
// LogoutByToken clears everything stored for token: the token-id mapping, the
// token sign inside the account session (deleting the session once it holds
// no token signs) and the refresh token. The logger and, when set, the
// watcher are notified afterwards. It returns an error when the token maps to
// no id or when any storage operation fails.
func (e *Enforcer) LogoutByToken(token string) error {
	// resolve the owner of the token
	id := e.getIdByToken(token)
	if id == "" {
		return errors.New("user not logged in")
	}

	// delete token-id
	if err := e.deleteIdByToken(token); err != nil {
		return err
	}

	if session := e.GetSession(id); session != nil {
		// delete tokenSign and store the updated session
		session.RemoveTokenSign(token)
		if err := e.UpdateSession(id, session); err != nil {
			return err
		}
		// a session without token signs is deleted
		if session.TokenSignSize() == 0 {
			if err := e.DeleteSession(id); err != nil {
				return err
			}
		}
	}

	if err := e.deleteRefreshToken(token); err != nil {
		return err
	}

	e.logger.Logout(e.loginType, id, token)
	if e.watcher != nil {
		e.watcher.Logout(e.loginType, id, token)
	}
	return nil
}
// IsLoginById reports whether the account id is logged in.
//
// Every token sign of the account (or only those of device[0] when it is
// given and non-empty) is checked with CheckLoginByToken; the first one that
// validates yields (true, nil). When none validates the result is false
// together with the error of the last check, or a nil error when there was
// nothing to check.
func (e *Enforcer) IsLoginById(id string, device ...string) (bool, error) {
	session := e.GetSession(id)
	if session == nil {
		return false, nil
	}

	var candidates []*model.TokenSign
	if len(device) == 0 || device[0] == "" {
		candidates = session.TokenSignList
	} else {
		candidates = session.GetFilterTokenSignSlice(device[0])
	}

	var lastErr error
	for _, sign := range candidates {
		if lastErr = e.CheckLoginByToken(sign.Value); lastErr == nil {
			return true, nil
		}
	}
	return false, lastErr
}
// GetId reads the request token from ctx and returns the id stored for it in
// the adapter. The value is not checked:
//   - -4 means the user has been replaced
//   - -5 means the user has been kicked
//   - -6 means the user has been banned
func (e *Enforcer) GetId(ctx ctx.Context) string {
	return e.GetIdByToken(e.GetRequestToken(ctx))
}
// GetIdByToken returns the id stored for token in the adapter. An empty token
// yields an empty id without a lookup.
func (e *Enforcer) GetIdByToken(token string) string {
	if token != "" {
		return e.getIdByToken(token)
	}
	return ""
}
// IsLogin reports whether the request token carried by ctx belongs to a
// logged-in user. See IsLoginByToken for the result semantics.
func (e *Enforcer) IsLogin(ctx ctx.Context) (bool, error) {
	return e.IsLoginByToken(e.GetRequestToken(ctx))
}
// IsLoginByToken reports whether tokenValue belongs to a logged-in user.
// An empty token yields (false, nil); a failed check yields false together
// with the error from CheckLoginByToken.
func (e *Enforcer) IsLoginByToken(tokenValue string) (bool, error) {
	if tokenValue == "" {
		return false, nil
	}
	if err := e.CheckLoginByToken(tokenValue); err != nil {
		return false, err
	}
	return true, nil
}
// CheckLogin validates the login carried by ctx and returns the error from
// GetLoginId, or nil when the login is valid.
func (e *Enforcer) CheckLogin(ctx ctx.Context) error {
	_, err := e.GetLoginId(ctx)
	return err
}
// CheckLoginByToken validates the login identified by token and returns the
// error from GetLoginIdByToken, or nil when the login is valid.
func (e *Enforcer) CheckLoginByToken(token string) error {
	_, err := e.GetLoginIdByToken(token)
	return err
}
// GetLoginId returns the validated login id for the request token carried by
// ctx. See GetLoginIdByToken for the checks performed.
func (e *Enforcer) GetLoginId(ctx ctx.Context) (string, error) {
	return e.GetLoginIdByToken(e.GetRequestToken(ctx))
}
// GetLoginIdByToken returns the login id stored for token after checking it
// with checkId. It fails when no id is stored for the token or when the id
// does not pass the check.
//
// With AutoRenew enabled, the token and session timeouts are reset to the
// configured Timeout on every successful call; errors from those two updates
// are deliberately ignored.
func (e *Enforcer) GetLoginIdByToken(token string) (string, error) {
	loginID := e.GetIdByToken(token)
	if loginID == "" {
		return "", errors.New("GetLoginId() failed: not logged in")
	}

	if valid, err := e.checkId(loginID); !valid {
		return "", err
	}

	if !e.config.AutoRenew {
		return loginID, nil
	}

	// auto refresh timeout when the user accesses; best effort
	_ = e.updateTokenTimeout(token, e.config.Timeout)
	_ = e.UpdateSessionTimeout(loginID, e.config.Timeout)
	return loginID, nil
}
// GetLoginCount returns the number of token signs in the session of account
// id, restricted to device[0] when it is given and non-empty. An account
// without a session counts as 0.
func (e *Enforcer) GetLoginCount(id string, device ...string) int {
	session := e.GetSession(id)
	if session == nil {
		return 0
	}
	if len(device) == 0 || device[0] == "" {
		return session.TokenSignSize()
	}
	return session.GetFilterTokenSign(device[0]).Len()
}
// GetBannedTime returns the banned time of account id for service, as
// reported by getBannedTime.
func (e *Enforcer) GetBannedTime(id string, service string) int64 {
	return e.getBannedTime(id, service)
}
// GetRequestToken extracts the token from the request in ctx. The enabled
// sources are tried in the order header, cookie, post form, each under the
// configured TokenName; the first non-empty value wins. It returns "" when
// ctx is nil or no enabled source carries a token.
func (e *Enforcer) GetRequestToken(ctx ctx.Context) string {
	if ctx == nil {
		return ""
	}

	name := e.config.TokenName
	if e.config.IsReadHeader {
		if fromHeader := ctx.Request().Header(name); fromHeader != "" {
			return fromHeader
		}
	}
	if e.config.IsReadCookie {
		if fromCookie := ctx.Request().Cookie(name); fromCookie != "" {
			return fromCookie
		}
	}
	if e.config.IsReadBody {
		if fromBody := ctx.Request().PostForm(name); fromBody != "" {
			return fromBody
		}
	}
	return ""
}
// AddTokenGenerateFun registers f as the token generation strategy for
// tokenStyle by adding it to the enforcer's function map.
// The returned error is always nil.
func (e *Enforcer) AddTokenGenerateFun(tokenStyle string, f model.HandlerFunc) error {
e.generateFunc.AddFunc(tokenStyle, f)
return nil
}
// GetSession loads the session of account id from the adapter. It returns nil
// when the adapter has no value under the session key.
func (e *Enforcer) GetSession(id string) *model.Session {
	key := e.spliceSessionKey(id)
	stored := e.adapter.Get(key, util.GetType(&model.Session{}))
	if stored == nil {
		return nil
	}
	return stored.(*model.Session)
}
// SetSession stores session under the session key of account id with the
// given timeout and returns any error from the underlying write.
func (e *Enforcer) SetSession(id string, session *model.Session, timeout int64) error {
	key := e.spliceSessionKey(id)
	return e.notifySet(key, session, timeout)
}
// DeleteSession removes the session stored under the session key of account
// id and returns any error from the underlying delete.
func (e *Enforcer) DeleteSession(id string) error {
	key := e.spliceSessionKey(id)
	return e.notifyDelete(key)
}
// UpdateSession overwrites the session stored under the session key of
// account id with session and returns any error from the underlying update.
func (e *Enforcer) UpdateSession(id string, session *model.Session) error {
	key := e.spliceSessionKey(id)
	return e.notifyUpdate(key, session)
}
// UpdateSessionTimeout resets the timeout of the session of account id to
// timeout and returns any error from the underlying update.
//
// The session is addressed through the spliced session key, the same key
// that SetSession, UpdateSession and DeleteSession use, so that the update
// hits the stored session entry rather than the raw id.
func (e *Enforcer) UpdateSessionTimeout(id string, timeout int64) error {
	return e.notifyUpdateTimeout(e.spliceSessionKey(id), timeout)
}
// GetTokenConfig returns the enforcer's token configuration by value.
func (e *Enforcer) GetTokenConfig() config.TokenConfig {
return e.config
}
// GetLoginCounts returns the count the adapter reports for keys filtered by
// the session key prefix (the session key of the empty id). It requires the
// adapter to implement persist.BatchAdapter and fails otherwise.
func (e *Enforcer) GetLoginCounts() (int, error) {
	batch, supported := e.adapter.(persist.BatchAdapter)
	if !supported {
		return 0, fmt.Errorf("the adapter does not implement persist.BatchAdapter")
	}

	prefix := e.spliceSessionKey("")
	total, err := batch.GetCountsFilteredKey(prefix)
	if err != nil {
		return 0, err
	}
	return total, nil
}
// GetLoginTokenCounts returns the count the adapter reports for keys filtered
// by the token key prefix (the token key of the empty token). It requires the
// adapter to implement persist.BatchAdapter and fails otherwise.
func (e *Enforcer) GetLoginTokenCounts() (int, error) {
	batch, supported := e.adapter.(persist.BatchAdapter)
	if !supported {
		return 0, fmt.Errorf("the adapter does not implement persist.BatchAdapter")
	}

	prefix := e.spliceTokenKey("")
	total, err := batch.GetCountsFilteredKey(prefix)
	if err != nil {
		return 0, err
	}
	return total, nil
}
// GetRefreshToken returns the result of getRefreshTokenValue for tokenValue.
// NOTE(review): presumably the refresh token associated with the given access
// token — confirm against getRefreshTokenValue.
func (e *Enforcer) GetRefreshToken(tokenValue string) string {
return e.getRefreshTokenValue(tokenValue)
}
// RefreshToken performs a token refresh for refreshToken without a request
// context. The first element of refreshModel supplies the refresh options;
// when it is omitted model.DefaultRefresh() is used.
func (e *Enforcer) RefreshToken(refreshToken string, refreshModel ...*model.Refresh) (*model.RefreshRes, error) {
	if len(refreshModel) == 0 {
		return e.RefreshTokenByModel(refreshToken, model.DefaultRefresh(), nil)
	}
	return e.RefreshTokenByModel(refreshToken, refreshModel[0], nil)
}
// RefreshTokenByModel exchanges refreshToken for a new login.
//
// It looks up the sign stored for refreshToken, deletes the refresh token
// tied to the sign's token, and then logs the sign's account in again on the
// sign's device using the options in refreshModel. Any ctx values are
// forwarded to LoginByModel.
//
// It fails when refreshModel is nil, double tokens are disabled, the refresh
// token is unknown, or the delete or login step fails.
func (e *Enforcer) RefreshTokenByModel(refreshToken string, refreshModel *model.Refresh, ctx ...ctx.Context) (*model.RefreshRes, error) {
	switch {
	case refreshModel == nil:
		return nil, errors.New("arg refreshModel can not be nil")
	case !e.config.DoubleToken:
		return nil, fmt.Errorf("double tokens are not enabled")
	}

	sign := e.getRefreshTokenSign(refreshToken)
	if sign == nil {
		return nil, fmt.Errorf("the refresh token does not exist: %v", refreshToken)
	}
	if err := e.deleteRefreshToken(sign.Token); err != nil {
		return nil, err
	}

	newToken, err := e.LoginByModel(sign.Id, &model.Login{
		Device:              sign.Device,
		IsLastingCookie:     refreshModel.IsLastingCookie,
		Timeout:             refreshModel.Timeout,
		JwtData:             refreshModel.JwtData,
		Token:               refreshModel.Token,
		RefreshToken:        refreshModel.RefreshToken,
		RefreshTokenTimeout: refreshModel.RefreshTokenTimeout,
	}, ctx...)
	if err != nil {
		return nil, err
	}

	// NOTE(review): the result echoes the refresh token passed in, although
	// its entry was deleted above and LoginByModel creates a refresh token of
	// its own when DoubleToken is on — confirm whether the newly created
	// refresh token should be returned instead.
	return &model.RefreshRes{
		Token:        newToken,
		RefreshToken: refreshToken,
	}, nil
}
// GetLoginDevices returns the devices recorded in the session of account id,
// or nil when the account has no session.
func (e *Enforcer) GetLoginDevices(id string) []string {
	if session := e.GetSession(id); session != nil {
		return session.GetAllDevice()
	}
	return nil
}
// GetDeviceByToken returns the device recorded for token in the session of
// the token's owner. It returns "" when the owner has no session or the
// session holds no sign for the token.
func (e *Enforcer) GetDeviceByToken(token string) string {
	session := e.GetSession(e.getIdByToken(token))
	if session == nil {
		return ""
	}
	if sign := session.GetTokenSign(token); sign != nil {
		return sign.Device
	}
	return ""
}