refactor: delete enforcer.webctx

This commit is contained in:
weloe
2023-05-06 23:07:28 +08:00
parent 8f9a56d602
commit bbcdb0fedb
4 changed files with 73 additions and 77 deletions

View File

@@ -20,11 +20,10 @@ type Enforcer struct {
generateFunc model.GenerateTokenFunc
adapter persist.Adapter
watcher persist.Watcher
webCtx ctx.Context
logger log.Logger
}
func NewEnforcer(tokenConfig *config.TokenConfig, adapter persist.Adapter, ctx ctx.Context) (*Enforcer, error) {
func NewEnforcer(tokenConfig *config.TokenConfig, adapter persist.Adapter) (*Enforcer, error) {
fm := model.LoadFunctionMap()
if tokenConfig == nil || adapter == nil {
return nil, errors.New("NewEnforcer() params should be not nil")
@@ -34,7 +33,6 @@ func NewEnforcer(tokenConfig *config.TokenConfig, adapter persist.Adapter, ctx c
config: *tokenConfig,
generateFunc: fm,
adapter: adapter,
webCtx: ctx,
logger: &log.DefaultLogger{},
}, nil
}
@@ -47,7 +45,7 @@ func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context {
return httpCtx.NewHttpContext(req, writer)
}
func NewDefaultEnforcer(adapter persist.Adapter, ctx ctx.Context) (*Enforcer, error) {
func NewDefaultEnforcer(adapter persist.Adapter) (*Enforcer, error) {
fm := model.LoadFunctionMap()
if adapter == nil {
return nil, errors.New("NewDefaultEnforcer() params should be not nil")
@@ -57,12 +55,11 @@ func NewDefaultEnforcer(adapter persist.Adapter, ctx ctx.Context) (*Enforcer, er
config: *config.DefaultTokenConfig(),
generateFunc: fm,
adapter: adapter,
webCtx: ctx,
logger: &log.DefaultLogger{},
}, nil
}
func NewEnforcerByFile(conf string, adapter persist.Adapter, ctx ctx.Context) (*Enforcer, error) {
func NewEnforcerByFile(conf string, adapter persist.Adapter) (*Enforcer, error) {
if conf == "" || adapter == nil {
return nil, errors.New("NewEnforcerByFile() params should be not nil")
}
@@ -78,15 +75,10 @@ func NewEnforcerByFile(conf string, adapter persist.Adapter, ctx ctx.Context) (*
config: *(newConfig.(*config.FileConfig).TokenConfig),
generateFunc: fm,
adapter: adapter,
webCtx: ctx,
logger: &log.DefaultLogger{},
}, nil
}
func (e *Enforcer) SetContext(context ctx.Context) {
e.webCtx = context
}
func (e *Enforcer) SetType(t string) {
e.loginType = t
}
@@ -120,12 +112,12 @@ func (e *Enforcer) IsLogEnable() bool {
}
// Login login by id and default loginModel, return tokenValue and error
func (e *Enforcer) Login(id string) (string, error) {
return e.LoginByModel(id, model.DefaultLoginModel())
func (e *Enforcer) Login(id string, ctx ctx.Context) (string, error) {
return e.LoginByModel(id, model.DefaultLoginModel(), ctx)
}
// LoginByModel login by id and loginModel, return tokenValue and error
func (e *Enforcer) LoginByModel(id string, loginModel *model.Login) (string, error) {
func (e *Enforcer) LoginByModel(id string, loginModel *model.Login, ctx ctx.Context) (string, error) {
var err error
var session *model.Session
var tokenValue string
@@ -167,7 +159,7 @@ func (e *Enforcer) LoginByModel(id string, loginModel *model.Login) (string, err
}
// response token
err = e.responseToken(tokenValue, loginModel)
err = e.responseToken(tokenValue, loginModel, ctx)
if err != nil {
return "", err
}
@@ -251,15 +243,15 @@ func (e *Enforcer) Replaced(id string, device string) error {
}
// Logout user logout
func (e *Enforcer) Logout() error {
func (e *Enforcer) Logout(ctx ctx.Context) error {
tokenConfig := e.config
token := e.GetRequestToken()
token := e.GetRequestToken(ctx)
if token == "" {
return errors.New("logout() failed: token doesn't exist")
}
if e.config.IsReadCookie {
e.webCtx.Response().DeleteCookie(tokenConfig.TokenName,
ctx.Response().DeleteCookie(tokenConfig.TokenName,
tokenConfig.CookieConfig.Path,
tokenConfig.CookieConfig.Domain)
}
@@ -301,8 +293,8 @@ func (e *Enforcer) IsLoginById(id string) (bool, error) {
}
// IsLogin check if user logged in by token.
func (e *Enforcer) IsLogin() (bool, error) {
tokenValue := e.GetRequestToken()
func (e *Enforcer) IsLogin(ctx ctx.Context) (bool, error) {
tokenValue := e.GetRequestToken(ctx)
if tokenValue == "" {
return false, nil
}
@@ -314,8 +306,8 @@ func (e *Enforcer) IsLogin() (bool, error) {
return e.validateValue(str)
}
func (e *Enforcer) GetLoginId() (string, error) {
tokenValue := e.GetRequestToken()
func (e *Enforcer) GetLoginId(ctx ctx.Context) (string, error) {
tokenValue := e.GetRequestToken(ctx)
str := e.adapter.GetStr(e.spliceTokenKey(tokenValue))
if str == "" {
return "", errors.New("GetLoginId() failed: not logged in")
@@ -371,21 +363,24 @@ func (e *Enforcer) Kickout(id string, device string) error {
}
// GetRequestToken read token from requestHeader | cookie | requestBody
func (e *Enforcer) GetRequestToken() string {
func (e *Enforcer) GetRequestToken(ctx ctx.Context) string {
var tokenValue string
if ctx == nil {
return ""
}
if e.config.IsReadHeader {
if tokenValue = e.webCtx.Request().Header(e.config.TokenName); tokenValue != "" {
if tokenValue = ctx.Request().Header(e.config.TokenName); tokenValue != "" {
return tokenValue
}
}
if e.config.IsReadCookie {
if tokenValue = e.webCtx.Request().Cookie(e.config.TokenName); tokenValue != "" {
if tokenValue = ctx.Request().Cookie(e.config.TokenName); tokenValue != "" {
return tokenValue
}
}
if e.config.IsReadBody {
if tokenValue = e.webCtx.Request().PostForm(e.config.TokenName); tokenValue != "" {
if tokenValue = ctx.Request().PostForm(e.config.TokenName); tokenValue != "" {
return tokenValue
}
}

View File

@@ -10,23 +10,22 @@ import (
var _ IEnforcer = &Enforcer{}
type IEnforcer interface {
Login(id string) (string, error)
LoginByModel(id string, loginModel *model.Login) (string, error)
Logout() error
IsLogin() (bool, error)
Login(id string, ctx ctx.Context) (string, error)
LoginByModel(id string, loginModel *model.Login, ctx ctx.Context) (string, error)
Logout(ctx ctx.Context) error
IsLogin(ctx ctx.Context) (bool, error)
IsLoginById(id string) (bool, error)
GetLoginId() (string, error)
GetLoginId(ctx ctx.Context) (string, error)
Replaced(id string, device string) error
// Banned TODO
Banned(id string, service string) error
Kickout(id string, device string) error
GetRequestToken() string
GetRequestToken(ctx ctx.Context) string
SetType(t string)
GetType() string
SetContext(ctx ctx.Context)
GetAdapter() persist.Adapter
SetAdapter(adapter persist.Adapter)
SetWatcher(watcher persist.Watcher)

View File

@@ -3,6 +3,7 @@ package token_go
import (
"errors"
"github.com/weloe/token-go/constant"
"github.com/weloe/token-go/ctx"
"github.com/weloe/token-go/model"
"strconv"
)
@@ -47,8 +48,10 @@ func (e *Enforcer) createLoginToken(id string, loginModel *model.Login) (string,
}
// responseToken set token to cookie or header
func (e *Enforcer) responseToken(tokenValue string, loginModel *model.Login) error {
func (e *Enforcer) responseToken(tokenValue string, loginModel *model.Login, ctx ctx.Context) error {
if ctx == nil {
return nil
}
tokenConfig := e.config
// set token to cookie
@@ -58,7 +61,7 @@ func (e *Enforcer) responseToken(tokenValue string, loginModel *model.Login) err
cookieTimeout = -1
}
// add cookie use tokenConfig.CookieConfig
e.webCtx.Response().AddCookie(tokenConfig.TokenName,
ctx.Response().AddCookie(tokenConfig.TokenName,
tokenValue,
tokenConfig.CookieConfig.Path,
tokenConfig.CookieConfig.Domain,
@@ -67,7 +70,7 @@ func (e *Enforcer) responseToken(tokenValue string, loginModel *model.Login) err
// set token to header
if loginModel.IsWriteHeader {
e.webCtx.Response().SetHeader(tokenConfig.TokenName, tokenValue)
ctx.Response().SetHeader(tokenConfig.TokenName, tokenValue)
}
return nil

View File

@@ -32,7 +32,9 @@ func NewTestHttpContext(t *testing.T) (error, ctx.Context) {
func TestNewEnforcer(t *testing.T) {
adapter := NewDefaultAdapter()
ctx := httpCtx.NewHttpContext(nil, nil)
if ctx == nil {
t.Errorf("NewHttpContext failed: %v", ctx)
}
tokenConfig := &config.TokenConfig{
TokenName: "testToken",
Timeout: 60,
@@ -45,7 +47,7 @@ func TestNewEnforcer(t *testing.T) {
}
logger := &log.DefaultLogger{}
enforcer, err := NewEnforcer(tokenConfig, adapter, ctx)
enforcer, err := NewEnforcer(tokenConfig, adapter)
enforcer.SetType("u")
if enforcer.GetType() != "u" {
t.Error("enforcer.loginType should be user")
@@ -67,13 +69,9 @@ func TestNewEnforcer(t *testing.T) {
t.Error("enforcer.adapter should be equal to the passed adapter parameter")
}
if enforcer.webCtx != ctx {
t.Error("enforcer.webCtx should be equal to the passed ctx parameter")
}
}
func NewTestEnforcer(t *testing.T) (error, *Enforcer) {
func NewTestEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) {
reqBody := bytes.NewBufferString("test request body")
req, err := http.NewRequest("POST", "/test", reqBody)
if err != nil {
@@ -89,11 +87,11 @@ func NewTestEnforcer(t *testing.T) (error, *Enforcer) {
tokenConfig := config.DefaultTokenConfig()
enforcer, err := NewEnforcer(tokenConfig, adapter, ctx)
return err, enforcer
enforcer, err := NewEnforcer(tokenConfig, adapter)
return err, enforcer, ctx
}
func NewTestConcurrentEnforcer(t *testing.T) (error, *Enforcer) {
func NewTestConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) {
reqBody := bytes.NewBufferString("test request body")
req, err := http.NewRequest("POST", "/test", reqBody)
if err != nil {
@@ -111,11 +109,11 @@ func NewTestConcurrentEnforcer(t *testing.T) (error, *Enforcer) {
tokenConfig.IsConcurrent = true
tokenConfig.IsShare = false
enforcer, err := NewEnforcer(tokenConfig, adapter, ctx)
return err, enforcer
enforcer, err := NewEnforcer(tokenConfig, adapter)
return err, enforcer, ctx
}
func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer) {
func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer, ctx.Context) {
reqBody := bytes.NewBufferString("test request body")
req, err := http.NewRequest("POST", "/test", reqBody)
if err != nil {
@@ -133,12 +131,12 @@ func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer) {
tokenConfig.IsConcurrent = false
tokenConfig.IsShare = false
enforcer, err := NewEnforcer(tokenConfig, adapter, ctx)
return err, enforcer
enforcer, err := NewEnforcer(tokenConfig, adapter)
return err, enforcer, ctx
}
func TestNewEnforcerByFile(t *testing.T) {
err, ctx := NewTestHttpContext(t)
err, _ := NewTestHttpContext(t)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
@@ -146,7 +144,7 @@ func TestNewEnforcerByFile(t *testing.T) {
adapter := persist.NewDefaultAdapter()
conf := "testConf"
enforcer, err := NewEnforcerByFile(conf, adapter, ctx)
enforcer, err := NewEnforcerByFile(conf, adapter)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
@@ -158,24 +156,22 @@ func TestNewEnforcerByFile(t *testing.T) {
if enforcer.adapter != adapter {
t.Error("enforcer.adapter should be equal to the passed adapter parameter")
}
if enforcer.webCtx != ctx {
t.Error("enforcer.webCtx should be equal to the passed ctx parameter")
}
}
func TestEnforcer_Login(t *testing.T) {
err, enforcer := NewTestEnforcer(t)
err, enforcer, ctx := NewTestEnforcer(t)
enforcer.EnableLog()
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
loginId := "1"
_, err = enforcer.Login(loginId)
_, err = enforcer.Login(loginId, ctx)
if err != nil {
t.Errorf("LoginByModel() failed: %v", err)
}
_, err = enforcer.LoginByModel(loginId, model.DefaultLoginModel())
_, err = enforcer.LoginByModel(loginId, model.DefaultLoginModel(), ctx)
if err != nil {
t.Errorf("LoginByModel() failed: %v", err)
}
@@ -204,18 +200,18 @@ func TestEnforcer_Login(t *testing.T) {
}
func TestEnforcer_GetLoginId(t *testing.T) {
err, enforcer := NewTestEnforcer(t)
err, enforcer, ctx := NewTestEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
loginModel.Token = "233"
_, err = enforcer.LoginByModel("id", loginModel)
_, err = enforcer.LoginByModel("id", loginModel, ctx)
if err != nil {
t.Errorf("Login() failed: %v", err)
}
id, err := enforcer.GetLoginId()
id, err := enforcer.GetLoginId(ctx)
if err != nil {
t.Errorf("GetLoginId() failed: %v", err)
}
@@ -227,14 +223,14 @@ func TestEnforcer_GetLoginId(t *testing.T) {
}
func TestEnforcer_Logout(t *testing.T) {
err, enforcer := NewTestEnforcer(t)
err, enforcer, ctx := NewTestEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
loginModel.Token = "233"
token, err := enforcer.LoginByModel("id", loginModel)
token, err := enforcer.LoginByModel("id", loginModel, ctx)
if token != "233" {
t.Errorf("LoginByModel() failed: unexpected token value %s, want '233' ", token)
}
@@ -242,12 +238,12 @@ func TestEnforcer_Logout(t *testing.T) {
t.Errorf("Login() failed: %v", err)
}
err = enforcer.Logout()
err = enforcer.Logout(ctx)
if err != nil {
t.Errorf("Logout() failed: %v", err)
}
login, err := enforcer.IsLogin()
login, err := enforcer.IsLogin(ctx)
if login {
t.Errorf("IsLogin() failed: unexpected value %v", login)
}
@@ -257,14 +253,14 @@ func TestEnforcer_Logout(t *testing.T) {
}
func TestEnforcer_Kickout(t *testing.T) {
err, enforcer := NewTestEnforcer(t)
err, enforcer, ctx := NewTestEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
loginModel.Token = "233"
_, err = enforcer.LoginByModel("id", loginModel)
_, err = enforcer.LoginByModel("id", loginModel, ctx)
if err != nil {
t.Errorf("Login() failed: %v", err)
}
@@ -278,7 +274,7 @@ func TestEnforcer_Kickout(t *testing.T) {
if session != nil {
t.Errorf("unexpected session value %v", session)
}
login, err := enforcer.IsLogin()
login, err := enforcer.IsLogin(ctx)
if login {
t.Errorf("IsLogin() failed: unexpected value %v", login)
}
@@ -290,7 +286,7 @@ func TestEnforcer_Kickout(t *testing.T) {
}
func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
err, enforcer := NewTestNotConcurrentEnforcer(t)
err, enforcer, ctx := NewTestNotConcurrentEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
@@ -298,7 +294,7 @@ func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
loginModel := model.DefaultLoginModel()
for i := 0; i < 4; i++ {
_, err = enforcer.LoginByModel("id", loginModel)
_, err = enforcer.LoginByModel("id", loginModel, ctx)
if err != nil {
t.Errorf("Login() failed: %v", err)
}
@@ -311,14 +307,14 @@ func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
}
func TestEnforcer_ConcurrentShare(t *testing.T) {
err, enforcer := NewTestEnforcer(t)
err, enforcer, ctx := NewTestEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
for i := 0; i < 5; i++ {
_, err = enforcer.LoginByModel("id", loginModel)
_, err = enforcer.LoginByModel("id", loginModel, ctx)
if err != nil {
t.Errorf("Login() failed: %v", err)
}
@@ -330,14 +326,14 @@ func TestEnforcer_ConcurrentShare(t *testing.T) {
}
func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) {
err, enforcer := NewTestConcurrentEnforcer(t)
err, enforcer, ctx := NewTestConcurrentEnforcer(t)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
loginModel := model.DefaultLoginModel()
for i := 0; i < 14; i++ {
_, err = enforcer.LoginByModel("id", loginModel)
_, err = enforcer.LoginByModel("id", loginModel, ctx)
if err != nil {
t.Errorf("Login() failed: %v", err)
}
@@ -351,11 +347,14 @@ func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) {
func TestNewDefaultEnforcer(t *testing.T) {
err, ctx := NewTestHttpContext(t)
if ctx == nil {
t.Errorf("NewTestHttpContext() failed: %v", err)
}
if err != nil {
t.Errorf("NewTestHttpContext() failed: %v", err)
}
enforcer, err := NewDefaultEnforcer(persist.NewDefaultAdapter(), ctx)
enforcer, err := NewDefaultEnforcer(persist.NewDefaultAdapter())
if err != nil || enforcer == nil {
t.Errorf("NewEnforcer() failed: %v", err)
}