reafacor: simplify constructor function

This commit is contained in:
weloe
2023-05-13 13:05:53 +08:00
parent f58ba4d93f
commit ab16d567f8
3 changed files with 14 additions and 30 deletions

View File

@@ -99,7 +99,7 @@ func main() {
IsShare: true, IsShare: true,
MaxLoginCount: -1, MaxLoginCount: -1,
} }
enforcer, err = tokenGo.NewEnforcer(tokenConfig, adapter) enforcer, err = tokenGo.NewEnforcer(adapter, tokenConfig)
} }
``` ```

View File

@@ -34,7 +34,7 @@ func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context {
return httpCtx.NewHttpContext(req, writer) return httpCtx.NewHttpContext(req, writer)
} }
func NewEnforcer(args ...interface{}) (*Enforcer, error) { func NewEnforcer(adapter persist.Adapter, args ...interface{}) (*Enforcer, error) {
var err error var err error
var enforcer *Enforcer var enforcer *Enforcer
if len(args) > 2 { if len(args) > 2 {
@@ -44,32 +44,16 @@ func NewEnforcer(args ...interface{}) (*Enforcer, error) {
return nil, errors.New("NewEnforcer() failed: parameters cannot be nil") return nil, errors.New("NewEnforcer() failed: parameters cannot be nil")
} }
if len(args) == 1 { if len(args) == 0 {
switch args[0].(type) { enforcer, err = InitWithDefaultConfig(adapter)
case persist.Adapter: } else if len(args) == 1 {
enforcer, err = InitWithDefaultConfig(args[0].(persist.Adapter))
case *config.TokenConfig:
adapter, ok := args[1].(persist.Adapter)
if ok {
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
} else {
return nil, fmt.Errorf("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter")
}
}
} else if len(args) == 2 {
adapter, ok := args[1].(persist.Adapter)
if !ok {
return nil, errors.New("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter")
}
switch args[0].(type) { switch args[0].(type) {
case *config.TokenConfig: case *config.TokenConfig:
if ok {
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter) enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
} else {
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
}
case string: case string:
enforcer, err = InitWithFile(args[0].(string), adapter) enforcer, err = InitWithFile(args[0].(string), adapter)
default:
return nil, errors.New("NewEnforcer() failed: the second parameter should be *TokenConfig or string")
} }
} }

View File

@@ -48,7 +48,7 @@ func TestNewEnforcer(t *testing.T) {
} }
logger := &log.DefaultLogger{} logger := &log.DefaultLogger{}
enforcer, err := NewEnforcer(tokenConfig, adapter) enforcer, err := NewEnforcer(adapter, tokenConfig)
enforcer.SetType("u") enforcer.SetType("u")
if enforcer.GetType() != "u" { if enforcer.GetType() != "u" {
t.Error("enforcer.loginType should be user") t.Error("enforcer.loginType should be user")
@@ -88,7 +88,7 @@ func NewTestEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) {
tokenConfig := config.DefaultTokenConfig() tokenConfig := config.DefaultTokenConfig()
enforcer, err := NewEnforcer(tokenConfig, adapter) enforcer, err := NewEnforcer(adapter, tokenConfig)
return err, enforcer, ctx return err, enforcer, ctx
} }
@@ -110,7 +110,7 @@ func NewTestConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) {
tokenConfig.IsConcurrent = true tokenConfig.IsConcurrent = true
tokenConfig.IsShare = false tokenConfig.IsShare = false
enforcer, err := NewEnforcer(tokenConfig, adapter) enforcer, err := NewEnforcer(adapter, tokenConfig)
return err, enforcer, ctx return err, enforcer, ctx
} }
@@ -134,7 +134,7 @@ func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context)
tokenConfig.IsShare = false tokenConfig.IsShare = false
tokenConfig.DataRefreshPeriod = 30 tokenConfig.DataRefreshPeriod = 30
enforcer, err := NewEnforcer(tokenConfig, adapter) enforcer, err := NewEnforcer(adapter, tokenConfig)
return err, enforcer, ctx return err, enforcer, ctx
} }
@@ -147,7 +147,7 @@ func TestNewEnforcerByFile(t *testing.T) {
adapter := persist.NewDefaultAdapter() adapter := persist.NewDefaultAdapter()
conf := "./examples/token_conf.yaml" conf := "./examples/token_conf.yaml"
enforcer, err := NewEnforcer(conf, adapter) enforcer, err := NewEnforcer(adapter, conf)
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
@@ -372,7 +372,7 @@ func TestNewEnforcer1(t *testing.T) {
enforcer, err := NewEnforcer(NewDefaultAdapter()) enforcer, err := NewEnforcer(NewDefaultAdapter())
t.Log(err) t.Log(err)
t.Log(enforcer) t.Log(enforcer)
enforcer, err = NewEnforcer(config.DefaultTokenConfig(), NewDefaultAdapter()) enforcer, err = NewEnforcer(NewDefaultAdapter(), config.DefaultTokenConfig())
t.Log(err) t.Log(err)
t.Log(enforcer) t.Log(enforcer)
} }