mirror of
https://github.com/weloe/token-go.git
synced 2025-10-05 07:26:50 +08:00
refactor: improve NewEnforcer()
This commit is contained in:
92
enforcer.go
92
enforcer.go
@@ -2,6 +2,7 @@ package token_go
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/weloe/token-go/config"
|
"github.com/weloe/token-go/config"
|
||||||
"github.com/weloe/token-go/constant"
|
"github.com/weloe/token-go/constant"
|
||||||
"github.com/weloe/token-go/ctx"
|
"github.com/weloe/token-go/ctx"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
"github.com/weloe/token-go/log"
|
"github.com/weloe/token-go/log"
|
||||||
"github.com/weloe/token-go/model"
|
"github.com/weloe/token-go/model"
|
||||||
"github.com/weloe/token-go/persist"
|
"github.com/weloe/token-go/persist"
|
||||||
|
"github.com/weloe/token-go/util"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
@@ -23,20 +25,6 @@ type Enforcer struct {
|
|||||||
logger log.Logger
|
logger log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEnforcer(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
|
|
||||||
fm := model.LoadFunctionMap()
|
|
||||||
if tokenConfig == nil || adapter == nil {
|
|
||||||
return nil, errors.New("NewEnforcer() params should be not nil")
|
|
||||||
}
|
|
||||||
return &Enforcer{
|
|
||||||
loginType: "user",
|
|
||||||
config: *tokenConfig,
|
|
||||||
generateFunc: fm,
|
|
||||||
adapter: adapter,
|
|
||||||
logger: &log.DefaultLogger{},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDefaultAdapter() persist.Adapter {
|
func NewDefaultAdapter() persist.Adapter {
|
||||||
return persist.NewDefaultAdapter()
|
return persist.NewDefaultAdapter()
|
||||||
}
|
}
|
||||||
@@ -45,34 +33,76 @@ func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context {
|
|||||||
return httpCtx.NewHttpContext(req, writer)
|
return httpCtx.NewHttpContext(req, writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultEnforcer(adapter persist.Adapter) (*Enforcer, error) {
|
func NewEnforcer(args ...interface{}) (*Enforcer, error) {
|
||||||
fm := model.LoadFunctionMap()
|
var err error
|
||||||
if adapter == nil {
|
var enforcer *Enforcer
|
||||||
return nil, errors.New("NewDefaultEnforcer() params should be not nil")
|
if len(args) > 2 {
|
||||||
|
return nil, fmt.Errorf("NewEnforcer() failed: unexpected args length = %v, it should be less than or equal to 2", len(args))
|
||||||
}
|
}
|
||||||
return &Enforcer{
|
if util.HasNil(args) {
|
||||||
loginType: "user",
|
return nil, errors.New("NewEnforcer() failed: parameters cannot be nil")
|
||||||
config: *config.DefaultTokenConfig(),
|
}
|
||||||
generateFunc: fm,
|
|
||||||
adapter: adapter,
|
if len(args) == 1 {
|
||||||
logger: &log.DefaultLogger{},
|
switch args[0].(type) {
|
||||||
}, nil
|
case persist.Adapter:
|
||||||
|
enforcer, err = InitWithDefaultConfig(args[0].(persist.Adapter))
|
||||||
|
case *config.TokenConfig:
|
||||||
|
adapter, ok := args[1].(persist.Adapter)
|
||||||
|
if ok {
|
||||||
|
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if len(args) == 2 {
|
||||||
|
adapter, ok := args[1].(persist.Adapter)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("NewEnforcer() failed: unexpected args[1] type, it should be persist.Adapter")
|
||||||
|
}
|
||||||
|
switch args[0].(type) {
|
||||||
|
case *config.TokenConfig:
|
||||||
|
if ok {
|
||||||
|
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
|
||||||
|
} else {
|
||||||
|
enforcer, err = InitWithConfig(args[0].(*config.TokenConfig), adapter)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
enforcer, err = InitWithFile(args[0].(string), adapter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return enforcer, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEnforcerByFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
|
func InitWithDefaultConfig(adapter persist.Adapter) (*Enforcer, error) {
|
||||||
|
if adapter == nil {
|
||||||
|
return nil, errors.New("InitWithDefaultConfig() failed: parameters cannot be nil")
|
||||||
|
}
|
||||||
|
return InitWithConfig(config.DefaultTokenConfig(), adapter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitWithFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
|
||||||
if conf == "" || adapter == nil {
|
if conf == "" || adapter == nil {
|
||||||
return nil, errors.New("NewEnforcerByFile() params should be not nil")
|
return nil, errors.New("InitWithFile() failed: parameters cannot be nil")
|
||||||
}
|
}
|
||||||
newConfig, err := config.NewConfig(conf)
|
newConfig, err := config.NewConfig(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
fm := model.LoadFunctionMap()
|
enforcer, err := InitWithConfig(newConfig.(*config.FileConfig).TokenConfig, adapter)
|
||||||
|
enforcer.conf = conf
|
||||||
|
return enforcer, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitWithConfig(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
|
||||||
|
fm := model.LoadFunctionMap()
|
||||||
|
if tokenConfig == nil || adapter == nil {
|
||||||
|
return nil, errors.New("InitWithConfig() failed: parameters cannot be nil")
|
||||||
|
}
|
||||||
return &Enforcer{
|
return &Enforcer{
|
||||||
loginType: "user",
|
loginType: "user",
|
||||||
conf: conf,
|
config: *tokenConfig,
|
||||||
config: *(newConfig.(*config.FileConfig).TokenConfig),
|
|
||||||
generateFunc: fm,
|
generateFunc: fm,
|
||||||
adapter: adapter,
|
adapter: adapter,
|
||||||
logger: &log.DefaultLogger{},
|
logger: &log.DefaultLogger{},
|
||||||
@@ -335,7 +365,7 @@ func (e *Enforcer) Kickout(id string, device string) error {
|
|||||||
if tokenSign, ok := element.Value.(*model.TokenSign); ok {
|
if tokenSign, ok := element.Value.(*model.TokenSign); ok {
|
||||||
elementV := tokenSign.Value
|
elementV := tokenSign.Value
|
||||||
session.RemoveTokenSign(elementV)
|
session.RemoveTokenSign(elementV)
|
||||||
// sign token replaced
|
// sign token kicked
|
||||||
err := e.adapter.UpdateStr(e.spliceTokenKey(elementV), strconv.Itoa(constant.BeKicked))
|
err := e.adapter.UpdateStr(e.spliceTokenKey(elementV), strconv.Itoa(constant.BeKicked))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@@ -124,6 +124,7 @@ func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context)
|
|||||||
req.Header.Set(constant.TokenName, "233")
|
req.Header.Set(constant.TokenName, "233")
|
||||||
|
|
||||||
ctx := httpCtx.NewHttpContext(req, rr)
|
ctx := httpCtx.NewHttpContext(req, rr)
|
||||||
|
t.Log(ctx)
|
||||||
|
|
||||||
adapter := persist.NewDefaultAdapter()
|
adapter := persist.NewDefaultAdapter()
|
||||||
|
|
||||||
@@ -144,7 +145,7 @@ func TestNewEnforcerByFile(t *testing.T) {
|
|||||||
adapter := persist.NewDefaultAdapter()
|
adapter := persist.NewDefaultAdapter()
|
||||||
conf := "testConf"
|
conf := "testConf"
|
||||||
|
|
||||||
enforcer, err := NewEnforcerByFile(conf, adapter)
|
enforcer, err := NewEnforcer(conf, adapter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Unexpected error: %v", err)
|
t.Errorf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -163,7 +164,7 @@ func TestEnforcer_Login(t *testing.T) {
|
|||||||
err, enforcer, ctx := NewTestEnforcer(t)
|
err, enforcer, ctx := NewTestEnforcer(t)
|
||||||
enforcer.EnableLog()
|
enforcer.EnableLog()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("NewEnforcer() failed: %v", err)
|
t.Errorf("InitWithConfig() failed: %v", err)
|
||||||
}
|
}
|
||||||
loginId := "1"
|
loginId := "1"
|
||||||
_, err = enforcer.Login(loginId, ctx)
|
_, err = enforcer.Login(loginId, ctx)
|
||||||
@@ -202,7 +203,7 @@ func TestEnforcer_Login(t *testing.T) {
|
|||||||
func TestEnforcer_GetLoginId(t *testing.T) {
|
func TestEnforcer_GetLoginId(t *testing.T) {
|
||||||
err, enforcer, ctx := NewTestEnforcer(t)
|
err, enforcer, ctx := NewTestEnforcer(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("NewEnforcer() failed: %v", err)
|
t.Errorf("InitWithConfig() failed: %v", err)
|
||||||
}
|
}
|
||||||
loginModel := model.DefaultLoginModel()
|
loginModel := model.DefaultLoginModel()
|
||||||
loginModel.Token = "233"
|
loginModel.Token = "233"
|
||||||
@@ -225,7 +226,7 @@ func TestEnforcer_GetLoginId(t *testing.T) {
|
|||||||
func TestEnforcer_Logout(t *testing.T) {
|
func TestEnforcer_Logout(t *testing.T) {
|
||||||
err, enforcer, ctx := NewTestEnforcer(t)
|
err, enforcer, ctx := NewTestEnforcer(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("NewEnforcer() failed: %v", err)
|
t.Errorf("InitWithConfig() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginModel := model.DefaultLoginModel()
|
loginModel := model.DefaultLoginModel()
|
||||||
@@ -255,7 +256,7 @@ func TestEnforcer_Logout(t *testing.T) {
|
|||||||
func TestEnforcer_Kickout(t *testing.T) {
|
func TestEnforcer_Kickout(t *testing.T) {
|
||||||
err, enforcer, ctx := NewTestEnforcer(t)
|
err, enforcer, ctx := NewTestEnforcer(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("NewEnforcer() failed: %v", err)
|
t.Errorf("InitWithConfig() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginModel := model.DefaultLoginModel()
|
loginModel := model.DefaultLoginModel()
|
||||||
@@ -288,7 +289,7 @@ func TestEnforcer_Kickout(t *testing.T) {
|
|||||||
func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
|
func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
|
||||||
err, enforcer, ctx := NewTestNotConcurrentEnforcer(t)
|
err, enforcer, ctx := NewTestNotConcurrentEnforcer(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("NewEnforcer() failed: %v", err)
|
t.Errorf("InitWithConfig() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginModel := model.DefaultLoginModel()
|
loginModel := model.DefaultLoginModel()
|
||||||
@@ -309,7 +310,7 @@ func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
|
|||||||
func TestEnforcer_ConcurrentShare(t *testing.T) {
|
func TestEnforcer_ConcurrentShare(t *testing.T) {
|
||||||
err, enforcer, ctx := NewTestEnforcer(t)
|
err, enforcer, ctx := NewTestEnforcer(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("NewEnforcer() failed: %v", err)
|
t.Errorf("InitWithConfig() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginModel := model.DefaultLoginModel()
|
loginModel := model.DefaultLoginModel()
|
||||||
@@ -328,7 +329,7 @@ func TestEnforcer_ConcurrentShare(t *testing.T) {
|
|||||||
func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) {
|
func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) {
|
||||||
err, enforcer, ctx := NewTestConcurrentEnforcer(t)
|
err, enforcer, ctx := NewTestConcurrentEnforcer(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("NewEnforcer() failed: %v", err)
|
t.Errorf("InitWithConfig() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginModel := model.DefaultLoginModel()
|
loginModel := model.DefaultLoginModel()
|
||||||
@@ -354,8 +355,18 @@ func TestNewDefaultEnforcer(t *testing.T) {
|
|||||||
t.Errorf("NewTestHttpContext() failed: %v", err)
|
t.Errorf("NewTestHttpContext() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
enforcer, err := NewDefaultEnforcer(persist.NewDefaultAdapter())
|
enforcer, err := NewEnforcer(persist.NewDefaultAdapter())
|
||||||
|
|
||||||
if err != nil || enforcer == nil {
|
if err != nil || enforcer == nil {
|
||||||
t.Errorf("NewEnforcer() failed: %v", err)
|
t.Errorf("InitWithConfig() failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewEnforcer1(t *testing.T) {
|
||||||
|
enforcer, err := NewEnforcer(NewDefaultAdapter())
|
||||||
|
t.Log(err)
|
||||||
|
t.Log(enforcer)
|
||||||
|
enforcer, err = NewEnforcer(config.DefaultTokenConfig(), NewDefaultAdapter())
|
||||||
|
t.Log(err)
|
||||||
|
t.Log(enforcer)
|
||||||
|
}
|
||||||
|
10
util/util.go
Normal file
10
util/util.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
func HasNil(arr []interface{}) bool {
|
||||||
|
for _, elem := range arr {
|
||||||
|
if elem == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
Reference in New Issue
Block a user