diff --git a/enforcer.go b/enforcer.go index a9d4257..6ff3f0b 100644 --- a/enforcer.go +++ b/enforcer.go @@ -2,6 +2,7 @@ package token_go import ( "errors" + "fmt" "github.com/weloe/token-go/config" "github.com/weloe/token-go/constant" "github.com/weloe/token-go/ctx" @@ -9,6 +10,7 @@ import ( "github.com/weloe/token-go/log" "github.com/weloe/token-go/model" "github.com/weloe/token-go/persist" + "github.com/weloe/token-go/util" "net/http" "strconv" ) @@ -23,20 +25,6 @@ type Enforcer struct { logger log.Logger } -func NewEnforcer(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) { - fm := model.LoadFunctionMap() - if tokenConfig == nil || adapter == nil { - return nil, errors.New("NewEnforcer() params should be not nil") - } - return &Enforcer{ - loginType: "user", - config: *tokenConfig, - generateFunc: fm, - adapter: adapter, - logger: &log.DefaultLogger{}, - }, nil -} - func NewDefaultAdapter() persist.Adapter { return persist.NewDefaultAdapter() } @@ -45,34 +33,76 @@ func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context { return httpCtx.NewHttpContext(req, writer) } -func NewDefaultEnforcer(adapter persist.Adapter) (*Enforcer, error) { - fm := model.LoadFunctionMap() - if adapter == nil { - return nil, errors.New("NewDefaultEnforcer() params should be not nil") +func NewEnforcer(args ...interface{}) (*Enforcer, error) { + var err error + var enforcer *Enforcer + if len(args) > 2 { + return nil, fmt.Errorf("NewEnforcer() failed: unexpected args length = %v, it should be less than or equal to 2", len(args)) } - return &Enforcer{ - loginType: "user", - config: *config.DefaultTokenConfig(), - generateFunc: fm, - adapter: adapter, - logger: &log.DefaultLogger{}, - }, nil + if util.HasNil(args) { + return nil, errors.New("NewEnforcer() failed: parameters cannot be nil") + } + + if len(args) == 1 { + switch args[0].(type) { + case persist.Adapter: + enforcer, err = InitWithDefaultConfig(args[0].(persist.Adapter)) + case *config.TokenConfig: + adapter, ok := args[1].(persist.Adapter) + if ok { + enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter) + } else { + return nil, fmt.Errorf("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter") + } + } + } else if len(args) == 2 { + adapter, ok := args[1].(persist.Adapter) + if !ok { + return nil, errors.New("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter") + } + switch args[0].(type) { + case *config.TokenConfig: + if ok { + enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter) + } else { + enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter) + } + case string: + enforcer, err = InitWithFile(args[0].(string), adapter) + } + } + + return enforcer, err } -func NewEnforcerByFile(conf string, adapter persist.Adapter) (*Enforcer, error) { +func InitWithDefaultConfig(adapter persist.Adapter) (*Enforcer, error) { + if adapter == nil { + return nil, errors.New("InitWithDefaultConfig() failed: parameters cannot be nil") + } + return InitWithConfig(config.DefaultTokenConfig(), adapter) +} + +func InitWithFile(conf string, adapter persist.Adapter) (*Enforcer, error) { if conf == "" || adapter == nil { - return nil, errors.New("NewEnforcerByFile() params should be not nil") + return nil, errors.New("InitWithFile() failed: parameters cannot be nil") } newConfig, err := config.NewConfig(conf) if err != nil { return nil, err } - fm := model.LoadFunctionMap() + enforcer, err := InitWithConfig(newConfig.(*config.FileConfig).TokenConfig, adapter) + enforcer.conf = conf + return enforcer, err +} +func InitWithConfig(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) { + fm := model.LoadFunctionMap() + if tokenConfig == nil || adapter == nil { + return nil, errors.New("InitWithConfig() failed: parameters cannot be nil") + } return &Enforcer{ loginType: "user", - conf: conf, - config: *(newConfig.(*config.FileConfig).TokenConfig), + config: *tokenConfig, generateFunc: fm, adapter: adapter, logger: &log.DefaultLogger{}, @@ -335,7 +365,7 @@ func (e *Enforcer) Kickout(id string, device string) error { if tokenSign, ok := element.Value.(*model.TokenSign); ok { elementV := tokenSign.Value session.RemoveTokenSign(elementV) - // sign token replaced + // sign token kicked err := e.adapter.UpdateStr(e.spliceTokenKey(elementV), strconv.Itoa(constant.BeKicked)) if err != nil { return err diff --git a/enforcer_test.go b/enforcer_test.go index 5ec3294..e9fdf37 100644 --- a/enforcer_test.go +++ b/enforcer_test.go @@ -124,6 +124,7 @@ func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) req.Header.Set(constant.TokenName, "233") ctx := httpCtx.NewHttpContext(req, rr) + t.Log(ctx) adapter := persist.NewDefaultAdapter() @@ -144,7 +145,7 @@ func TestNewEnforcerByFile(t *testing.T) { adapter := persist.NewDefaultAdapter() conf := "testConf" - enforcer, err := NewEnforcerByFile(conf, adapter) + enforcer, err := NewEnforcer(conf, adapter) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -163,7 +164,7 @@ func TestEnforcer_Login(t *testing.T) { err, enforcer, ctx := NewTestEnforcer(t) enforcer.EnableLog() if err != nil { - t.Errorf("NewEnforcer() failed: %v", err) + t.Errorf("InitWithConfig() failed: %v", err) } loginId := "1" _, err = enforcer.Login(loginId, ctx) @@ -202,7 +203,7 @@ func TestEnforcer_Login(t *testing.T) { func TestEnforcer_GetLoginId(t *testing.T) { err, enforcer, ctx := NewTestEnforcer(t) if err != nil { - t.Errorf("NewEnforcer() failed: %v", err) + t.Errorf("InitWithConfig() failed: %v", err) } loginModel := model.DefaultLoginModel() loginModel.Token = "233" @@ -225,7 +226,7 @@ func TestEnforcer_GetLoginId(t *testing.T) { func TestEnforcer_Logout(t *testing.T) { err, enforcer, ctx := NewTestEnforcer(t) if err != nil { - t.Errorf("NewEnforcer() failed: %v", err) + t.Errorf("InitWithConfig() failed: %v", err) } loginModel := model.DefaultLoginModel() @@ -255,7 +256,7 @@ func TestEnforcer_Logout(t *testing.T) { func TestEnforcer_Kickout(t *testing.T) { err, enforcer, ctx := NewTestEnforcer(t) if err != nil { - t.Errorf("NewEnforcer() failed: %v", err) + t.Errorf("InitWithConfig() failed: %v", err) } loginModel := model.DefaultLoginModel() @@ -288,7 +289,7 @@ func TestEnforcer_Kickout(t *testing.T) { func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) { err, enforcer, ctx := NewTestNotConcurrentEnforcer(t) if err != nil { - t.Errorf("NewEnforcer() failed: %v", err) + t.Errorf("InitWithConfig() failed: %v", err) } loginModel := model.DefaultLoginModel() @@ -309,7 +310,7 @@ func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) { func TestEnforcer_ConcurrentShare(t *testing.T) { err, enforcer, ctx := NewTestEnforcer(t) if err != nil { - t.Errorf("NewEnforcer() failed: %v", err) + t.Errorf("InitWithConfig() failed: %v", err) } loginModel := model.DefaultLoginModel() @@ -328,7 +329,7 @@ func TestEnforcer_ConcurrentShare(t *testing.T) { func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) { err, enforcer, ctx := NewTestConcurrentEnforcer(t) if err != nil { - t.Errorf("NewEnforcer() failed: %v", err) + t.Errorf("InitWithConfig() failed: %v", err) } loginModel := model.DefaultLoginModel() @@ -354,8 +355,18 @@ func TestNewDefaultEnforcer(t *testing.T) { t.Errorf("NewTestHttpContext() failed: %v", err) } - enforcer, err := NewDefaultEnforcer(persist.NewDefaultAdapter()) + enforcer, err := NewEnforcer(persist.NewDefaultAdapter()) + if err != nil || enforcer == nil { - t.Errorf("NewEnforcer() failed: %v", err) + t.Errorf("InitWithConfig() failed: %v", err) } } + +func TestNewEnforcer1(t *testing.T) { + enforcer, err := NewEnforcer(NewDefaultAdapter()) + t.Log(err) + t.Log(enforcer) + enforcer, err = NewEnforcer(config.DefaultTokenConfig(), NewDefaultAdapter()) + t.Log(err) + t.Log(enforcer) +} diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..c85c3ad --- /dev/null +++ b/util/util.go @@ -0,0 +1,10 @@ +package util + +func HasNil(arr []interface{}) bool { + for _, elem := range arr { + if elem == nil { + return true + } + } + return false +}