mirror of
https://github.com/weloe/token-go.git
synced 2025-10-04 23:22:41 +08:00
refactor: improve NewEnforcer()
This commit is contained in:
92
enforcer.go
92
enforcer.go
@@ -2,6 +2,7 @@ package token_go
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/weloe/token-go/config"
|
||||
"github.com/weloe/token-go/constant"
|
||||
"github.com/weloe/token-go/ctx"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"github.com/weloe/token-go/log"
|
||||
"github.com/weloe/token-go/model"
|
||||
"github.com/weloe/token-go/persist"
|
||||
"github.com/weloe/token-go/util"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
@@ -23,20 +25,6 @@ type Enforcer struct {
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func NewEnforcer(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
|
||||
fm := model.LoadFunctionMap()
|
||||
if tokenConfig == nil || adapter == nil {
|
||||
return nil, errors.New("NewEnforcer() params should be not nil")
|
||||
}
|
||||
return &Enforcer{
|
||||
loginType: "user",
|
||||
config: *tokenConfig,
|
||||
generateFunc: fm,
|
||||
adapter: adapter,
|
||||
logger: &log.DefaultLogger{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewDefaultAdapter() persist.Adapter {
|
||||
return persist.NewDefaultAdapter()
|
||||
}
|
||||
@@ -45,34 +33,76 @@ func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context {
|
||||
return httpCtx.NewHttpContext(req, writer)
|
||||
}
|
||||
|
||||
func NewDefaultEnforcer(adapter persist.Adapter) (*Enforcer, error) {
|
||||
fm := model.LoadFunctionMap()
|
||||
if adapter == nil {
|
||||
return nil, errors.New("NewDefaultEnforcer() params should be not nil")
|
||||
func NewEnforcer(args ...interface{}) (*Enforcer, error) {
|
||||
var err error
|
||||
var enforcer *Enforcer
|
||||
if len(args) > 2 {
|
||||
return nil, fmt.Errorf("NewEnforcer() failed: unexpected args length = %v, it should be less than or equal to 2", len(args))
|
||||
}
|
||||
return &Enforcer{
|
||||
loginType: "user",
|
||||
config: *config.DefaultTokenConfig(),
|
||||
generateFunc: fm,
|
||||
adapter: adapter,
|
||||
logger: &log.DefaultLogger{},
|
||||
}, nil
|
||||
if util.HasNil(args) {
|
||||
return nil, errors.New("NewEnforcer() failed: parameters cannot be nil")
|
||||
}
|
||||
|
||||
func NewEnforcerByFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
|
||||
if len(args) == 1 {
|
||||
switch args[0].(type) {
|
||||
case persist.Adapter:
|
||||
enforcer, err = InitWithDefaultConfig(args[0].(persist.Adapter))
|
||||
case *config.TokenConfig:
|
||||
adapter, ok := args[1].(persist.Adapter)
|
||||
if ok {
|
||||
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
|
||||
} else {
|
||||
return nil, fmt.Errorf("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter")
|
||||
}
|
||||
}
|
||||
} else if len(args) == 2 {
|
||||
adapter, ok := args[1].(persist.Adapter)
|
||||
if !ok {
|
||||
return nil, errors.New("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter")
|
||||
}
|
||||
switch args[0].(type) {
|
||||
case *config.TokenConfig:
|
||||
if ok {
|
||||
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
|
||||
} else {
|
||||
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
|
||||
}
|
||||
case string:
|
||||
enforcer, err = InitWithFile(args[0].(string), adapter)
|
||||
}
|
||||
}
|
||||
|
||||
return enforcer, err
|
||||
}
|
||||
|
||||
func InitWithDefaultConfig(adapter persist.Adapter) (*Enforcer, error) {
|
||||
if adapter == nil {
|
||||
return nil, errors.New("InitWithDefaultConfig() failed: parameters cannot be nil")
|
||||
}
|
||||
return InitWithConfig(config.DefaultTokenConfig(), adapter)
|
||||
}
|
||||
|
||||
func InitWithFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
|
||||
if conf == "" || adapter == nil {
|
||||
return nil, errors.New("NewEnforcerByFile() params should be not nil")
|
||||
return nil, errors.New("InitWithFile() failed: parameters cannot be nil")
|
||||
}
|
||||
newConfig, err := config.NewConfig(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fm := model.LoadFunctionMap()
|
||||
enforcer, err := InitWithConfig(newConfig.(*config.FileConfig).TokenConfig, adapter)
|
||||
enforcer.conf = conf
|
||||
return enforcer, err
|
||||
}
|
||||
|
||||
func InitWithConfig(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
|
||||
fm := model.LoadFunctionMap()
|
||||
if tokenConfig == nil || adapter == nil {
|
||||
return nil, errors.New("InitWithConfig() failed: parameters cannot be nil")
|
||||
}
|
||||
return &Enforcer{
|
||||
loginType: "user",
|
||||
conf: conf,
|
||||
config: *(newConfig.(*config.FileConfig).TokenConfig),
|
||||
config: *tokenConfig,
|
||||
generateFunc: fm,
|
||||
adapter: adapter,
|
||||
logger: &log.DefaultLogger{},
|
||||
@@ -335,7 +365,7 @@ func (e *Enforcer) Kickout(id string, device string) error {
|
||||
if tokenSign, ok := element.Value.(*model.TokenSign); ok {
|
||||
elementV := tokenSign.Value
|
||||
session.RemoveTokenSign(elementV)
|
||||
// sign token replaced
|
||||
// sign token kicked
|
||||
err := e.adapter.UpdateStr(e.spliceTokenKey(elementV), strconv.Itoa(constant.BeKicked))
|
||||
if err != nil {
|
||||
return err
|
||||
|
@@ -124,6 +124,7 @@ func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context)
|
||||
req.Header.Set(constant.TokenName, "233")
|
||||
|
||||
ctx := httpCtx.NewHttpContext(req, rr)
|
||||
t.Log(ctx)
|
||||
|
||||
adapter := persist.NewDefaultAdapter()
|
||||
|
||||
@@ -144,7 +145,7 @@ func TestNewEnforcerByFile(t *testing.T) {
|
||||
adapter := persist.NewDefaultAdapter()
|
||||
conf := "testConf"
|
||||
|
||||
enforcer, err := NewEnforcerByFile(conf, adapter)
|
||||
enforcer, err := NewEnforcer(conf, adapter)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
@@ -163,7 +164,7 @@ func TestEnforcer_Login(t *testing.T) {
|
||||
err, enforcer, ctx := NewTestEnforcer(t)
|
||||
enforcer.EnableLog()
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
t.Errorf("InitWithConfig() failed: %v", err)
|
||||
}
|
||||
loginId := "1"
|
||||
_, err = enforcer.Login(loginId, ctx)
|
||||
@@ -202,7 +203,7 @@ func TestEnforcer_Login(t *testing.T) {
|
||||
func TestEnforcer_GetLoginId(t *testing.T) {
|
||||
err, enforcer, ctx := NewTestEnforcer(t)
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
t.Errorf("InitWithConfig() failed: %v", err)
|
||||
}
|
||||
loginModel := model.DefaultLoginModel()
|
||||
loginModel.Token = "233"
|
||||
@@ -225,7 +226,7 @@ func TestEnforcer_GetLoginId(t *testing.T) {
|
||||
func TestEnforcer_Logout(t *testing.T) {
|
||||
err, enforcer, ctx := NewTestEnforcer(t)
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
t.Errorf("InitWithConfig() failed: %v", err)
|
||||
}
|
||||
|
||||
loginModel := model.DefaultLoginModel()
|
||||
@@ -255,7 +256,7 @@ func TestEnforcer_Logout(t *testing.T) {
|
||||
func TestEnforcer_Kickout(t *testing.T) {
|
||||
err, enforcer, ctx := NewTestEnforcer(t)
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
t.Errorf("InitWithConfig() failed: %v", err)
|
||||
}
|
||||
|
||||
loginModel := model.DefaultLoginModel()
|
||||
@@ -288,7 +289,7 @@ func TestEnforcer_Kickout(t *testing.T) {
|
||||
func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
|
||||
err, enforcer, ctx := NewTestNotConcurrentEnforcer(t)
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
t.Errorf("InitWithConfig() failed: %v", err)
|
||||
}
|
||||
|
||||
loginModel := model.DefaultLoginModel()
|
||||
@@ -309,7 +310,7 @@ func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
|
||||
func TestEnforcer_ConcurrentShare(t *testing.T) {
|
||||
err, enforcer, ctx := NewTestEnforcer(t)
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
t.Errorf("InitWithConfig() failed: %v", err)
|
||||
}
|
||||
|
||||
loginModel := model.DefaultLoginModel()
|
||||
@@ -328,7 +329,7 @@ func TestEnforcer_ConcurrentShare(t *testing.T) {
|
||||
func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) {
|
||||
err, enforcer, ctx := NewTestConcurrentEnforcer(t)
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
t.Errorf("InitWithConfig() failed: %v", err)
|
||||
}
|
||||
|
||||
loginModel := model.DefaultLoginModel()
|
||||
@@ -354,8 +355,18 @@ func TestNewDefaultEnforcer(t *testing.T) {
|
||||
t.Errorf("NewTestHttpContext() failed: %v", err)
|
||||
}
|
||||
|
||||
enforcer, err := NewDefaultEnforcer(persist.NewDefaultAdapter())
|
||||
enforcer, err := NewEnforcer(persist.NewDefaultAdapter())
|
||||
|
||||
if err != nil || enforcer == nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
t.Errorf("InitWithConfig() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEnforcer1(t *testing.T) {
|
||||
enforcer, err := NewEnforcer(NewDefaultAdapter())
|
||||
t.Log(err)
|
||||
t.Log(enforcer)
|
||||
enforcer, err = NewEnforcer(config.DefaultTokenConfig(), NewDefaultAdapter())
|
||||
t.Log(err)
|
||||
t.Log(enforcer)
|
||||
}
|
||||
|
10
util/util.go
Normal file
10
util/util.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package util
|
||||
|
||||
func HasNil(arr []interface{}) bool {
|
||||
for _, elem := range arr {
|
||||
if elem == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
Reference in New Issue
Block a user