diff --git a/Readme.md b/Readme.md index 6a52715..6e672b9 100644 --- a/Readme.md +++ b/Readme.md @@ -99,7 +99,7 @@ func main() { IsShare: true, MaxLoginCount: -1, } - enforcer, err = tokenGo.NewEnforcer(tokenConfig, adapter) + enforcer, err = tokenGo.NewEnforcer(adapter, tokenConfig) } ``` diff --git a/enforcer.go b/enforcer.go index 17f801c..e219a17 100644 --- a/enforcer.go +++ b/enforcer.go @@ -34,7 +34,7 @@ func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context { return httpCtx.NewHttpContext(req, writer) } -func NewEnforcer(args ...interface{}) (*Enforcer, error) { +func NewEnforcer(adapter persist.Adapter, args ...interface{}) (*Enforcer, error) { var err error var enforcer *Enforcer if len(args) > 2 { @@ -44,32 +44,16 @@ func NewEnforcer(args ...interface{}) (*Enforcer, error) { return nil, errors.New("NewEnforcer() failed: parameters cannot be nil") } - if len(args) == 1 { - switch args[0].(type) { - case persist.Adapter: - enforcer, err = InitWithDefaultConfig(args[0].(persist.Adapter)) - case *config.TokenConfig: - adapter, ok := args[1].(persist.Adapter) - if ok { - enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter) - } else { - return nil, fmt.Errorf("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter") - } - } - } else if len(args) == 2 { - adapter, ok := args[1].(persist.Adapter) - if !ok { - return nil, errors.New("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter") - } + if len(args) == 0 { + enforcer, err = InitWithDefaultConfig(adapter) + } else if len(args) == 1 { switch args[0].(type) { case *config.TokenConfig: - if ok { - enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter) - } else { - enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter) - } + enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter) case string: enforcer, err = InitWithFile(args[0].(string), adapter) + default: + return nil, errors.New("NewEnforcer() failed: the second parameter should be *TokenConfig or string") } } diff --git a/enforcer_test.go b/enforcer_test.go index 6ba5d9f..1996305 100644 --- a/enforcer_test.go +++ b/enforcer_test.go @@ -48,7 +48,7 @@ func TestNewEnforcer(t *testing.T) { } logger := &log.DefaultLogger{} - enforcer, err := NewEnforcer(tokenConfig, adapter) + enforcer, err := NewEnforcer(adapter, tokenConfig) enforcer.SetType("u") if enforcer.GetType() != "u" { t.Error("enforcer.loginType should be user") @@ -88,7 +88,7 @@ func NewTestEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) { tokenConfig := config.DefaultTokenConfig() - enforcer, err := NewEnforcer(tokenConfig, adapter) + enforcer, err := NewEnforcer(adapter, tokenConfig) return err, enforcer, ctx } @@ -110,7 +110,7 @@ func NewTestConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) { tokenConfig.IsConcurrent = true tokenConfig.IsShare = false - enforcer, err := NewEnforcer(tokenConfig, adapter) + enforcer, err := NewEnforcer(adapter, tokenConfig) return err, enforcer, ctx } @@ -134,7 +134,7 @@ func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) tokenConfig.IsShare = false tokenConfig.DataRefreshPeriod = 30 - enforcer, err := NewEnforcer(tokenConfig, adapter) + enforcer, err := NewEnforcer(adapter, tokenConfig) return err, enforcer, ctx } @@ -147,7 +147,7 @@ func TestNewEnforcerByFile(t *testing.T) { adapter := persist.NewDefaultAdapter() conf := "./examples/token_conf.yaml" - enforcer, err := NewEnforcer(conf, adapter) + enforcer, err := NewEnforcer(adapter, conf) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -372,7 +372,7 @@ func TestNewEnforcer1(t *testing.T) { enforcer, err := NewEnforcer(NewDefaultAdapter()) t.Log(err) t.Log(enforcer) - enforcer, err = NewEnforcer(config.DefaultTokenConfig(), NewDefaultAdapter()) + enforcer, err = NewEnforcer(NewDefaultAdapter(), config.DefaultTokenConfig()) t.Log(err) t.Log(enforcer) }