refactor: improve NewEnforcer()

This commit is contained in:
weloe
2023-05-06 23:24:53 +08:00
parent bbcdb0fedb
commit 25f85c7114
3 changed files with 92 additions and 41 deletions

View File

@@ -2,6 +2,7 @@ package token_go
import (
"errors"
"fmt"
"github.com/weloe/token-go/config"
"github.com/weloe/token-go/constant"
"github.com/weloe/token-go/ctx"
@@ -9,6 +10,7 @@ import (
"github.com/weloe/token-go/log"
"github.com/weloe/token-go/model"
"github.com/weloe/token-go/persist"
"github.com/weloe/token-go/util"
"net/http"
"strconv"
)
@@ -23,20 +25,6 @@ type Enforcer struct {
logger log.Logger
}
func NewEnforcer(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
fm := model.LoadFunctionMap()
if tokenConfig == nil || adapter == nil {
return nil, errors.New("NewEnforcer() params should be not nil")
}
return &Enforcer{
loginType: "user",
config: *tokenConfig,
generateFunc: fm,
adapter: adapter,
logger: &log.DefaultLogger{},
}, nil
}
func NewDefaultAdapter() persist.Adapter {
return persist.NewDefaultAdapter()
}
@@ -45,34 +33,76 @@ func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context {
return httpCtx.NewHttpContext(req, writer)
}
func NewDefaultEnforcer(adapter persist.Adapter) (*Enforcer, error) {
fm := model.LoadFunctionMap()
if adapter == nil {
return nil, errors.New("NewDefaultEnforcer() params should be not nil")
func NewEnforcer(args ...interface{}) (*Enforcer, error) {
var err error
var enforcer *Enforcer
if len(args) > 2 {
return nil, fmt.Errorf("NewEnforcer() failed: unexpected args length = %v, it should be less than or equal to 2", len(args))
}
return &Enforcer{
loginType: "user",
config: *config.DefaultTokenConfig(),
generateFunc: fm,
adapter: adapter,
logger: &log.DefaultLogger{},
}, nil
if util.HasNil(args) {
return nil, errors.New("NewEnforcer() failed: parameters cannot be nil")
}
func NewEnforcerByFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
if len(args) == 1 {
switch args[0].(type) {
case persist.Adapter:
enforcer, err = InitWithDefaultConfig(args[0].(persist.Adapter))
case *config.TokenConfig:
adapter, ok := args[1].(persist.Adapter)
if ok {
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
} else {
return nil, fmt.Errorf("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter")
}
}
} else if len(args) == 2 {
adapter, ok := args[1].(persist.Adapter)
if !ok {
return nil, errors.New("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter")
}
switch args[0].(type) {
case *config.TokenConfig:
if ok {
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
} else {
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
}
case string:
enforcer, err = InitWithFile(args[0].(string), adapter)
}
}
return enforcer, err
}
func InitWithDefaultConfig(adapter persist.Adapter) (*Enforcer, error) {
if adapter == nil {
return nil, errors.New("InitWithDefaultConfig() failed: parameters cannot be nil")
}
return InitWithConfig(config.DefaultTokenConfig(), adapter)
}
func InitWithFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
if conf == "" || adapter == nil {
return nil, errors.New("NewEnforcerByFile() params should be not nil")
return nil, errors.New("InitWithFile() failed: parameters cannot be nil")
}
newConfig, err := config.NewConfig(conf)
if err != nil {
return nil, err
}
fm := model.LoadFunctionMap()
enforcer, err := InitWithConfig(newConfig.(*config.FileConfig).TokenConfig, adapter)
enforcer.conf = conf
return enforcer, err
}
func InitWithConfig(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
fm := model.LoadFunctionMap()
if tokenConfig == nil || adapter == nil {
return nil, errors.New("InitWithConfig() failed: parameters cannot be nil")
}
return &Enforcer{
loginType: "user",
conf: conf,
config: *(newConfig.(*config.FileConfig).TokenConfig),
config: *tokenConfig,
generateFunc: fm,
adapter: adapter,
logger: &log.DefaultLogger{},
@@ -335,7 +365,7 @@ func (e *Enforcer) Kickout(id string, device string) error {
if tokenSign, ok := element.Value.(*model.TokenSign); ok {
elementV := tokenSign.Value
session.RemoveTokenSign(elementV)
// sign token replaced
// sign token kicked
err := e.adapter.UpdateStr(e.spliceTokenKey(elementV), strconv.Itoa(constant.BeKicked))
if err != nil {
return err

View File

@@ -124,6 +124,7 @@ func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context)
req.Header.Set(constant.TokenName, "233")
ctx := httpCtx.NewHttpContext(req, rr)
t.Log(ctx)
adapter := persist.NewDefaultAdapter()
@@ -144,7 +145,7 @@ func TestNewEnforcerByFile(t *testing.T) {
adapter := persist.NewDefaultAdapter()
conf := "testConf"
enforcer, err := NewEnforcerByFile(conf, adapter)
enforcer, err := NewEnforcer(conf, adapter)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
@@ -163,7 +164,7 @@ func TestEnforcer_Login(t *testing.T) {
err, enforcer, ctx := NewTestEnforcer(t)
enforcer.EnableLog()
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
t.Errorf("InitWithConfig() failed: %v", err)
}
loginId := "1"
_, err = enforcer.Login(loginId, ctx)
@@ -202,7 +203,7 @@ func TestEnforcer_Login(t *testing.T) {
func TestEnforcer_GetLoginId(t *testing.T) {
err, enforcer, ctx := NewTestEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
t.Errorf("InitWithConfig() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
loginModel.Token = "233"
@@ -225,7 +226,7 @@ func TestEnforcer_GetLoginId(t *testing.T) {
func TestEnforcer_Logout(t *testing.T) {
err, enforcer, ctx := NewTestEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
t.Errorf("InitWithConfig() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
@@ -255,7 +256,7 @@ func TestEnforcer_Logout(t *testing.T) {
func TestEnforcer_Kickout(t *testing.T) {
err, enforcer, ctx := NewTestEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
t.Errorf("InitWithConfig() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
@@ -288,7 +289,7 @@ func TestEnforcer_Kickout(t *testing.T) {
func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
err, enforcer, ctx := NewTestNotConcurrentEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
t.Errorf("InitWithConfig() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
@@ -309,7 +310,7 @@ func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
func TestEnforcer_ConcurrentShare(t *testing.T) {
err, enforcer, ctx := NewTestEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
t.Errorf("InitWithConfig() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
@@ -328,7 +329,7 @@ func TestEnforcer_ConcurrentShare(t *testing.T) {
func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) {
err, enforcer, ctx := NewTestConcurrentEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
t.Errorf("InitWithConfig() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
@@ -354,8 +355,18 @@ func TestNewDefaultEnforcer(t *testing.T) {
t.Errorf("NewTestHttpContext() failed: %v", err)
}
enforcer, err := NewDefaultEnforcer(persist.NewDefaultAdapter())
enforcer, err := NewEnforcer(persist.NewDefaultAdapter())
if err != nil || enforcer == nil {
t.Errorf("NewEnforcer() failed: %v", err)
t.Errorf("InitWithConfig() failed: %v", err)
}
}
func TestNewEnforcer1(t *testing.T) {
enforcer, err := NewEnforcer(NewDefaultAdapter())
t.Log(err)
t.Log(enforcer)
enforcer, err = NewEnforcer(config.DefaultTokenConfig(), NewDefaultAdapter())
t.Log(err)
t.Log(enforcer)
}

10
util/util.go Normal file
View File

@@ -0,0 +1,10 @@
package util
func HasNil(arr []interface{}) bool {
for _, elem := range arr {
if elem == nil {
return true
}
}
return false
}