diff --git a/enforcer.go b/enforcer.go new file mode 100644 index 0000000..85274af --- /dev/null +++ b/enforcer.go @@ -0,0 +1,417 @@ +package token_go + +import ( + "errors" + "github.com/weloe/token-go/config" + "github.com/weloe/token-go/constant" + "github.com/weloe/token-go/ctx" + httpCtx "github.com/weloe/token-go/ctx/go-http-context" + "github.com/weloe/token-go/log" + "github.com/weloe/token-go/model" + "github.com/weloe/token-go/persist" + "net/http" + "strconv" +) + +type Enforcer struct { + conf string + loginType string + config config.TokenConfig + generateFunc model.GenerateTokenFunc + adapter persist.Adapter + watcher persist.Watcher + webCtx ctx.Context + logger log.Logger +} + +func NewEnforcer(tokenConfig *config.TokenConfig, adapter persist.Adapter, ctx ctx.Context) (*Enforcer, error) { + fm := model.LoadFunctionMap() + if tokenConfig == nil || adapter == nil { + return nil, errors.New("NewEnforcer() params should be not nil") + } + return &Enforcer{ + loginType: "user", + config: *tokenConfig, + generateFunc: fm, + adapter: adapter, + webCtx: ctx, + logger: &log.DefaultLogger{}, + }, nil +} + +func NewDefaultAdapter() persist.Adapter { + return persist.NewDefaultAdapter() +} + +func NewHttpContext(req *http.Request, writer http.ResponseWriter) ctx.Context { + return httpCtx.NewHttpContext(req, writer) +} + +func NewDefaultEnforcer(adapter persist.Adapter, ctx ctx.Context) (*Enforcer, error) { + fm := model.LoadFunctionMap() + if adapter == nil { + return nil, errors.New("NewDefaultEnforcer() params should be not nil") + } + return &Enforcer{ + loginType: "user", + config: *config.DefaultTokenConfig(), + generateFunc: fm, + adapter: adapter, + webCtx: ctx, + logger: &log.DefaultLogger{}, + }, nil +} + +func NewEnforcerByFile(conf string, adapter persist.Adapter, ctx ctx.Context) (*Enforcer, error) { + if conf == "" || adapter == nil { + return nil, errors.New("NewEnforcerByFile() params should be not nil") + } + newConfig, err := config.NewConfig(conf) + if err != nil { + return nil, err + } + fm := model.LoadFunctionMap() + + return &Enforcer{ + loginType: "user", + conf: conf, + config: *(newConfig.(*config.FileConfig).TokenConfig), + generateFunc: fm, + adapter: adapter, + webCtx: ctx, + logger: &log.DefaultLogger{}, + }, nil +} + +func (e *Enforcer) SetContext(context ctx.Context) { + e.webCtx = context +} + +func (e *Enforcer) SetType(t string) { + e.loginType = t +} + +func (e *Enforcer) GetType() string { + return e.loginType +} + +func (e *Enforcer) GetAdapter() persist.Adapter { + return e.adapter +} + +func (e *Enforcer) SetAdapter(adapter persist.Adapter) { + e.adapter = adapter +} + +func (e *Enforcer) SetWatcher(watcher persist.Watcher) { + e.watcher = watcher +} + +func (e *Enforcer) SetLogger(logger log.Logger) { + e.logger = logger +} + +func (e *Enforcer) EnableLog() { + e.logger.Enable(true) +} + +func (e *Enforcer) IsLogEnable() bool { + return e.logger.IsEnabled() +} + +// Login login by id and default loginModel, return tokenValue and error +func (e *Enforcer) Login(id string) (string, error) { + return e.LoginByModel(id, model.DefaultLoginModel()) +} + +// LoginByModel login by id and loginModel, return tokenValue and error +func (e *Enforcer) LoginByModel(id string, loginModel *model.Login) (string, error) { + var err error + var session *model.Session + var tokenValue string + tokenConfig := e.config + + // allocate token + tokenValue, err = e.createLoginToken(id, loginModel) + + if err != nil { + return "", err + } + + // add tokenSign + if session = e.GetSession(id); session == nil { + session = model.NewSession(e.spliceSessionKey(id), "account-session", id) + session.AddTokenSign(&model.TokenSign{ + Value: tokenValue, + Device: loginModel.Device, + }) + } + + if !(tokenConfig.IsConcurrent && tokenConfig.IsShare) { + session.AddTokenSign(&model.TokenSign{ + Value: tokenValue, + Device: loginModel.Device, + }) + } + + // reset session + err = e.SetSession(id, session, loginModel.Timeout) + if err != nil { + return "", err + } + + // set token-id + err = e.adapter.SetStr(e.spliceTokenKey(tokenValue), id, loginModel.Timeout) + if err != nil { + return "", err + } + + // response token + err = e.responseToken(tokenValue, loginModel) + if err != nil { + return "", err + } + + // called watcher + m := &model.Login{ + Device: loginModel.Device, + IsLastingCookie: loginModel.IsLastingCookie, + Timeout: loginModel.Timeout, + JwtData: loginModel.JwtData, + Token: tokenValue, + IsWriteHeader: loginModel.IsWriteHeader, + } + + // called logger + e.logger.Login(e.loginType, id, tokenValue, m) + + if e.watcher != nil { + e.watcher.Login(e.loginType, id, tokenValue, m) + } + + // if login success check it + if tokenConfig.IsConcurrent && !tokenConfig.IsShare { + // check if the number of sessions for this account exceeds the maximum limit. + if tokenConfig.MaxLoginCount != -1 { + if session = e.GetSession(id); session != nil { + // logout account until loginCount == maxLoginCount if loginCount > maxLoginCount + for element, i := session.TokenSignList.Front(), 0; element != nil && i < session.TokenSignList.Len()-int(tokenConfig.MaxLoginCount); element, i = element.Next(), i+1 { + tokenSign := element.Value.(*model.TokenSign) + // delete tokenSign + session.RemoveTokenSign(tokenSign.Value) + // delete token-id + err = e.adapter.Delete(e.spliceTokenKey(tokenSign.Value)) + if err != nil { + return "", err + } + } + // check TokenSignList length, if length == 0, delete this session + if session != nil && session.TokenSignList.Len() == 0 { + err = e.deleteSession(id) + if err != nil { + return "", err + } + } + } + } + + } + + return tokenValue, nil +} + +// Replaced replace other user +func (e *Enforcer) Replaced(id string, device string) error { + if session := e.GetSession(id); session != nil { + // get by login device + tokenSignList := session.GetFilterTokenSign(device) + // sign account replaced + for element := tokenSignList.Front(); element != nil; element = element.Next() { + if tokenSign, ok := element.Value.(*model.TokenSign); ok { + elementV := tokenSign.Value + session.RemoveTokenSign(elementV) + // sign token replaced + err := e.adapter.UpdateStr(e.spliceTokenKey(elementV), strconv.Itoa(constant.BeReplaced)) + if err != nil { + return err + } + + // called logger + e.logger.Replace(e.loginType, id, tokenSign.Value) + + // called watcher + if e.watcher != nil { + e.watcher.Replace(e.loginType, id, tokenSign.Value) + } + } + } + + } + return nil +} + +// Logout user logout +func (e *Enforcer) Logout() error { + tokenConfig := e.config + + token := e.GetRequestToken() + if token == "" { + return errors.New("logout() failed: token doesn't exist") + } + if e.config.IsReadCookie { + e.webCtx.Response().DeleteCookie(tokenConfig.TokenName, + tokenConfig.CookieConfig.Path, + tokenConfig.CookieConfig.Domain) + } + + err := e.logoutByToken(token) + + if err != nil { + return err + } + return nil +} + +// IsLoginById check if user logged in by loginId. +// check all tokenValue and if one is validated return true +func (e *Enforcer) IsLoginById(id string) (bool, error) { + var error error + session := e.GetSession(id) + if session != nil { + l := session.TokenSignList + for element := l.Back(); element != nil; element = element.Prev() { + tokenSign := element.Value.(*model.TokenSign) + str := e.adapter.GetStr(e.spliceTokenKey(tokenSign.Value)) + if str == "" { + continue + } + value, err := e.validateValue(str) + if err != nil { + error = err + continue + } + if value { + return true, nil + } + + } + } + + return false, error +} + +// IsLogin check if user logged in by token. +func (e *Enforcer) IsLogin() (bool, error) { + tokenValue := e.GetRequestToken() + if tokenValue == "" { + return false, nil + } + str := e.adapter.GetStr(e.spliceTokenKey(tokenValue)) + if str == "" { + return false, nil + } + + return e.validateValue(str) +} + +func (e *Enforcer) GetLoginId() (string, error) { + tokenValue := e.GetRequestToken() + str := e.adapter.GetStr(e.spliceTokenKey(tokenValue)) + if str == "" { + return "", errors.New("GetLoginId() failed: not logged in") + } + validate, err := e.validateValue(str) + if !validate { + return "", err + } + + return str, nil +} + +func (e *Enforcer) Banned(id string, service string) error { + panic("implement me ...") +} + +// Kickout kickout user +func (e *Enforcer) Kickout(id string, device string) error { + session := e.GetSession(id) + if session != nil { + // get by login device + tokenSignList := session.GetFilterTokenSign(device) + // sign account kicked + for element := tokenSignList.Front(); element != nil; element = element.Next() { + if tokenSign, ok := element.Value.(*model.TokenSign); ok { + elementV := tokenSign.Value + session.RemoveTokenSign(elementV) + // sign token replaced + err := e.adapter.UpdateStr(e.spliceTokenKey(elementV), strconv.Itoa(constant.BeKicked)) + if err != nil { + return err + } + + // called logger + e.logger.Kickout(e.loginType, id, tokenSign.Value) + + // called watcher + if e.watcher != nil { + e.watcher.Kickout(e.loginType, id, tokenSign.Value) + } + } + } + + } + // check TokenSignList length, if length == 0, delete this session + if session != nil && session.TokenSignList.Len() == 0 { + err := e.deleteSession(id) + if err != nil { + return err + } + } + return nil +} + +// GetRequestToken read token from requestHeader | cookie | requestBody +func (e *Enforcer) GetRequestToken() string { + var tokenValue string + if e.config.IsReadHeader { + if tokenValue = e.webCtx.Request().Header(e.config.TokenName); tokenValue != "" { + return tokenValue + } + } + if e.config.IsReadCookie { + if tokenValue = e.webCtx.Request().Cookie(e.config.TokenName); tokenValue != "" { + return tokenValue + } + + } + if e.config.IsReadBody { + if tokenValue = e.webCtx.Request().PostForm(e.config.TokenName); tokenValue != "" { + return tokenValue + } + } + return tokenValue +} + +func (e *Enforcer) GetSession(id string) *model.Session { + if v := e.adapter.Get(e.spliceSessionKey(id)); v != nil { + session := v.(*model.Session) + return session + } + return nil +} + +func (e *Enforcer) SetSession(id string, session *model.Session, timeout int64) error { + err := e.adapter.Set(e.spliceSessionKey(id), session, timeout) + if err != nil { + return err + } + return nil +} + +func (e *Enforcer) deleteSession(id string) error { + err := e.adapter.Delete(e.spliceSessionKey(id)) + if err != nil { + return err + } + return nil +} diff --git a/enforcer_interface.go b/enforcer_interface.go new file mode 100644 index 0000000..37fda73 --- /dev/null +++ b/enforcer_interface.go @@ -0,0 +1,38 @@ +package token_go + +import ( + "github.com/weloe/token-go/ctx" + "github.com/weloe/token-go/log" + "github.com/weloe/token-go/model" + "github.com/weloe/token-go/persist" +) + +var _ IEnforcer = &Enforcer{} + +type IEnforcer interface { + Login(id string) (string, error) + LoginByModel(id string, loginModel *model.Login) (string, error) + Logout() error + IsLogin() (bool, error) + IsLoginById(id string) (bool, error) + GetLoginId() (string, error) + + Replaced(id string, device string) error + // Banned TODO + Banned(id string, service string) error + Kickout(id string, device string) error + + GetRequestToken() string + + SetType(t string) + GetType() string + SetContext(ctx ctx.Context) + GetAdapter() persist.Adapter + SetAdapter(adapter persist.Adapter) + SetWatcher(watcher persist.Watcher) + SetLogger(logger log.Logger) + EnableLog() + IsLogEnable() bool + GetSession(id string) *model.Session + SetSession(id string, session *model.Session, timeout int64) error +} diff --git a/enforcer_internal_api.go b/enforcer_internal_api.go new file mode 100644 index 0000000..bdcc3cb --- /dev/null +++ b/enforcer_internal_api.go @@ -0,0 +1,138 @@ +package token_go + +import ( + "errors" + "github.com/weloe/token-go/constant" + "github.com/weloe/token-go/model" + "strconv" +) + +// createLoginToken create by config.TokenConfig and model.Login +func (e *Enforcer) createLoginToken(id string, loginModel *model.Login) (string, error) { + tokenConfig := e.config + var tokenValue string + var err error + // if isConcurrent is false, + if !tokenConfig.IsConcurrent { + err = e.Replaced(id, loginModel.Device) + if err != nil { + return "", err + } + } + + // if loginModel set token, return directly + if loginModel.Token != "" { + return loginModel.Token, nil + } + + // if share token + if tokenConfig.IsConcurrent && tokenConfig.IsShare { + // reuse the previous token. + if v := e.GetSession(id); v != nil { + tokenValue = v.GetLastTokenByDevice(loginModel.Device) + if tokenValue != "" { + return tokenValue, nil + } + + } + } + + // create new token + tokenValue, err = e.generateFunc.Exec(tokenConfig.TokenStyle) + if err != nil { + return "", err + } + + return tokenValue, nil +} + +// responseToken set token to cookie or header +func (e *Enforcer) responseToken(tokenValue string, loginModel *model.Login) error { + + tokenConfig := e.config + + // set token to cookie + if tokenConfig.IsReadCookie { + cookieTimeout := tokenConfig.Timeout + if loginModel.IsLastingCookie { + cookieTimeout = -1 + } + // add cookie use tokenConfig.CookieConfig + e.webCtx.Response().AddCookie(tokenConfig.TokenName, + tokenValue, + tokenConfig.CookieConfig.Path, + tokenConfig.CookieConfig.Domain, + cookieTimeout) + } + + // set token to header + if loginModel.IsWriteHeader { + e.webCtx.Response().SetHeader(tokenConfig.TokenName, tokenValue) + } + + return nil +} + +// logoutByToken clear token info +func (e *Enforcer) logoutByToken(token string) error { + var err error + // delete token-id + id := e.adapter.GetStr(e.spliceTokenKey(token)) + if id == "" { + return errors.New("not logged in") + } + // delete token-id + err = e.adapter.Delete(e.spliceTokenKey(token)) + if err != nil { + return err + } + session := e.GetSession(id) + if session != nil { + // delete tokenSign + session.RemoveTokenSign(token) + } + // check TokenSignList length, if length == 0, delete this session + if session != nil && session.TokenSignList.Len() == 0 { + err = e.deleteSession(id) + if err != nil { + return err + } + } + + e.logger.Logout(e.loginType, id, token) + + if e.watcher != nil { + e.watcher.Logout(e.loginType, id, token) + } + + return nil +} + +// validateValue validate if value is proper +func (e *Enforcer) validateValue(str string) (bool, error) { + i, err := strconv.Atoi(str) + // if convert err return true + if err != nil { + return true, nil + } + if i == constant.BeReplaced { + return false, errors.New("this account is replaced") + } + if i == constant.BeKicked { + return false, errors.New("this account is kicked out") + } + if i == constant.BeBanned { + return false, errors.New("this account is banned") + } + return true, nil +} + +// spliceSessionKey splice session-id key +func (e *Enforcer) spliceSessionKey(id string) string { + return e.config.TokenName + ":" + e.loginType + ":session:" + id +} + +// spliceTokenKey splice token-id key +func (e *Enforcer) spliceTokenKey(id string) string { + return e.config.TokenName + ":" + e.loginType + ":token:" + id +} diff --git a/enforcer_test.go b/enforcer_test.go new file mode 100644 index 0000000..963bf0f --- /dev/null +++ b/enforcer_test.go @@ -0,0 +1,362 @@ +package token_go + +import ( + "bytes" + "fmt" + "github.com/weloe/token-go/config" + "github.com/weloe/token-go/constant" + "github.com/weloe/token-go/ctx" + httpCtx "github.com/weloe/token-go/ctx/go-http-context" + "github.com/weloe/token-go/log" + "github.com/weloe/token-go/model" + "github.com/weloe/token-go/persist" + "net/http" + "net/http/httptest" + "testing" +) + +func NewTestHttpContext(t *testing.T) (error, ctx.Context) { + reqBody := bytes.NewBufferString("test request body") + req, err := http.NewRequest("POST", "/test", reqBody) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + req.Header.Set(constant.TokenName, "233") + + ctx := NewHttpContext(req, rr) + return err, ctx +} + +func TestNewEnforcer(t *testing.T) { + adapter := NewDefaultAdapter() + ctx := httpCtx.NewHttpContext(nil, nil) + + tokenConfig := &config.TokenConfig{ + TokenName: "testToken", + Timeout: 60, + IsReadCookie: true, + IsReadHeader: true, + IsReadBody: false, + IsConcurrent: false, + IsShare: false, + MaxLoginCount: -1, + } + logger := &log.DefaultLogger{} + + enforcer, err := NewEnforcer(tokenConfig, adapter, ctx) + enforcer.SetType("u") + if enforcer.GetType() != "u" { + t.Error("enforcer.loginType should be user") + } + enforcer.SetAdapter(adapter) + enforcer.SetLogger(logger) + enforcer.SetWatcher(nil) + enforcer.EnableLog() + if !enforcer.IsLogEnable() { + t.Errorf("enforcer.IsLogEnable() should be %v", enforcer.IsLogEnable()) + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if enforcer.config != *tokenConfig { + t.Error("enforcer.config should be equal to the passed tokenConfig parameter") + } + if enforcer.GetAdapter() != adapter { + t.Error("enforcer.adapter should be equal to the passed adapter parameter") + } + + if enforcer.webCtx != ctx { + t.Error("enforcer.webCtx should be equal to the passed ctx parameter") + } + +} + +func NewTestEnforcer(t *testing.T) (error, *Enforcer) { + reqBody := bytes.NewBufferString("test request body") + req, err := http.NewRequest("POST", "/test", reqBody) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + req.Header.Set(constant.TokenName, "233") + + ctx := httpCtx.NewHttpContext(req, rr) + + adapter := persist.NewDefaultAdapter() + + tokenConfig := config.DefaultTokenConfig() + + enforcer, err := NewEnforcer(tokenConfig, adapter, ctx) + return err, enforcer +} + +func NewTestConcurrentEnforcer(t *testing.T) (error, *Enforcer) { + reqBody := bytes.NewBufferString("test request body") + req, err := http.NewRequest("POST", "/test", reqBody) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + req.Header.Set(constant.TokenName, "233") + + ctx := httpCtx.NewHttpContext(req, rr) + + adapter := persist.NewDefaultAdapter() + + tokenConfig := config.DefaultTokenConfig() + tokenConfig.IsConcurrent = true + tokenConfig.IsShare = false + + enforcer, err := NewEnforcer(tokenConfig, adapter, ctx) + return err, enforcer +} + +func NewTestNotConcurrentEnforcer(t *testing.T) (error, *Enforcer) { + reqBody := bytes.NewBufferString("test request body") + req, err := http.NewRequest("POST", "/test", reqBody) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + req.Header.Set(constant.TokenName, "233") + + ctx := httpCtx.NewHttpContext(req, rr) + + adapter := persist.NewDefaultAdapter() + + tokenConfig := config.DefaultTokenConfig() + tokenConfig.IsConcurrent = false + tokenConfig.IsShare = false + + enforcer, err := NewEnforcer(tokenConfig, adapter, ctx) + return err, enforcer +} + +func TestNewEnforcerByFile(t *testing.T) { + err, ctx := NewTestHttpContext(t) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + adapter := persist.NewDefaultAdapter() + conf := "testConf" + + enforcer, err := NewEnforcerByFile(conf, adapter, ctx) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if enforcer.conf != conf { + t.Error("enforcer.conf should be equal to the passed conf parameter") + } + + if enforcer.adapter != adapter { + t.Error("enforcer.adapter should be equal to the passed adapter parameter") + } + if enforcer.webCtx != ctx { + t.Error("enforcer.webCtx should be equal to the passed ctx parameter") + } +} + +func TestEnforcer_Login(t *testing.T) { + err, enforcer := NewTestEnforcer(t) + enforcer.EnableLog() + if err != nil { + t.Errorf("NewEnforcer() failed: %v", err) + } + loginId := "1" + _, err = enforcer.Login(loginId) + if err != nil { + t.Errorf("LoginByModel() failed: %v", err) + } + + _, err = enforcer.LoginByModel(loginId, model.DefaultLoginModel()) + if err != nil { + t.Errorf("LoginByModel() failed: %v", err) + } + login, err := enforcer.IsLoginById(loginId) + if err != nil { + t.Errorf("IsLoginById() failed: err should be nil now: %v", err) + } + if !login { + t.Errorf("IsLoginById() failed: IsLoginById() = %v", login) + } + err = enforcer.Replaced("1", "") + if err != nil { + t.Errorf("Replaced() failed: %v", err) + } + session := enforcer.GetSession("1") + t.Logf("id = %v session.tokenSign len = %v", "1", session.TokenSignList.Len()) + + login, err = enforcer.IsLoginById(loginId) + if err != nil { + t.Logf("%v error: %v", login, err) + } + if login { + t.Errorf("IsLoginById() failed: IsLoginById() = %v", login) + } + +} + +func TestEnforcer_GetLoginId(t *testing.T) { + err, enforcer := NewTestEnforcer(t) + if err != nil { + t.Errorf("NewEnforcer() failed: %v", err) + } + loginModel := model.DefaultLoginModel() + loginModel.Token = "233" + _, err = enforcer.LoginByModel("id", loginModel) + if err != nil { + t.Errorf("Login() failed: %v", err) + } + + id, err := enforcer.GetLoginId() + if err != nil { + t.Errorf("GetLoginId() failed: %v", err) + } + t.Logf("LoginId = %v", id) + if id != "id" { + t.Errorf("GetLoginId() failed: %v", err) + } + +} + +func TestEnforcer_Logout(t *testing.T) { + err, enforcer := NewTestEnforcer(t) + if err != nil { + t.Errorf("NewEnforcer() failed: %v", err) + } + + loginModel := model.DefaultLoginModel() + loginModel.Token = "233" + token, err := enforcer.LoginByModel("id", loginModel) + if token != "233" { + t.Errorf("LoginByModel() failed: unexpected token value %s, want '233' ", token) + } + if err != nil { + t.Errorf("Login() failed: %v", err) + } + + err = enforcer.Logout() + if err != nil { + t.Errorf("Logout() failed: %v", err) + } + + login, err := enforcer.IsLogin() + if login { + t.Errorf("IsLogin() failed: unexpected value %v", login) + } + if err != nil { + t.Errorf("err: %v", err) + } +} + +func TestEnforcer_Kickout(t *testing.T) { + err, enforcer := NewTestEnforcer(t) + if err != nil { + t.Errorf("NewEnforcer() failed: %v", err) + } + + loginModel := model.DefaultLoginModel() + loginModel.Token = "233" + _, err = enforcer.LoginByModel("id", loginModel) + if err != nil { + t.Errorf("Login() failed: %v", err) + } + + err = enforcer.Kickout("id", "") + if err != nil { + t.Errorf("Kickout() failed %v", err) + } + + session := enforcer.GetSession("id") + if session != nil { + t.Errorf("unexpected session value %v", session) + } + login, err := enforcer.IsLogin() + if login { + t.Errorf("IsLogin() failed: unexpected value %v", login) + } + n := fmt.Sprintf("%v", err) + if n != "this account is kicked out" { + t.Errorf("IsLogin() failed: unexpected error value %v", err) + } + +} + +func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) { + err, enforcer := NewTestNotConcurrentEnforcer(t) + if err != nil { + t.Errorf("NewEnforcer() failed: %v", err) + } + + loginModel := model.DefaultLoginModel() + + for i := 0; i < 4; i++ { + _, err = enforcer.LoginByModel("id", loginModel) + if err != nil { + t.Errorf("Login() failed: %v", err) + } + } + session := enforcer.GetSession("id") + if session.TokenSignList.Len() != 1 { + t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len()) + } + +} + +func TestEnforcer_ConcurrentShare(t *testing.T) { + err, enforcer := NewTestEnforcer(t) + if err != nil { + t.Errorf("NewEnforcer() failed: %v", err) + } + + loginModel := model.DefaultLoginModel() + for i := 0; i < 5; i++ { + _, err = enforcer.LoginByModel("id", loginModel) + if err != nil { + t.Errorf("Login() failed: %v", err) + } + } + session := enforcer.GetSession("id") + if session.TokenSignList.Len() != 1 { + t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len()) + } + +} +func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) { + err, enforcer := NewTestConcurrentEnforcer(t) + if err != nil { + t.Errorf("NewEnforcer() failed: %v", err) + } + + loginModel := model.DefaultLoginModel() + for i := 0; i < 14; i++ { + _, err = enforcer.LoginByModel("id", loginModel) + if err != nil { + t.Errorf("Login() failed: %v", err) + } + } + session := enforcer.GetSession("id") + if session.TokenSignList.Len() != 12 { + t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len()) + } + +} + +func TestNewDefaultEnforcer(t *testing.T) { + err, ctx := NewTestHttpContext(t) + if err != nil { + t.Errorf("NewTestHttpContext() failed: %v", err) + } + + enforcer, err := NewDefaultEnforcer(persist.NewDefaultAdapter(), ctx) + if err != nil || enforcer == nil { + t.Errorf("NewEnforcer() failed: %v", err) + } +} diff --git a/first.go b/first.go deleted file mode 100644 index 0face58..0000000 --- a/first.go +++ /dev/null @@ -1,8 +0,0 @@ -package token_go - -import "fmt" - -func First() { - version := "1.0" - _ = fmt.Sprintf("first init success %s", version) -} diff --git a/first_test.go b/first_test.go deleted file mode 100644 index e4ded9b..0000000 --- a/first_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package token_go - -import "testing" - -func TestFirst(t *testing.T) { - First() -} - -func BenchmarkFirst(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - First() - } -} diff --git a/model/session.go b/model/session.go new file mode 100644 index 0000000..e25a74b --- /dev/null +++ b/model/session.go @@ -0,0 +1,109 @@ +package model + +import ( + "container/list" + "sync" + "time" +) + +type TokenSign struct { + Value string + Device string +} + +type Session struct { + Id string + Type string + LoginType string + LoginId string + Token string + CreateTime int64 + DataMap *sync.Map + TokenSignList *list.List +} + +func DefaultSession(id string) *Session { + return &Session{ + Id: id, + CreateTime: time.Now().UnixMilli(), + } +} + +func NewSession(id string, sessionType string, loginId string) *Session { + return &Session{ + Id: id, + Type: sessionType, + LoginId: loginId, + CreateTime: time.Now().UnixMilli(), + TokenSignList: list.New(), + } +} + +// GetFilterTokenSign filter by TokenSign.Device from all TokenSign +func (s *Session) GetFilterTokenSign(device string) *list.List { + if device == "" { + return s.GetTokenSignListCopy() + } + copyList := list.New() + for e := s.TokenSignList.Front(); e != nil; e = e.Next() { + if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Device == device { + copyList.PushBack(tokenSign) + } + } + return copyList +} + +// GetTokenSignListCopy find all TokenSign +func (s *Session) GetTokenSignListCopy() *list.List { + copyList := list.New() + for e := s.TokenSignList.Front(); e != nil; e = e.Next() { + copyList.PushBack(e.Value) + } + return copyList +} + +// GetTokenSign find TokenSign by TokenSign.Value +func (s *Session) GetTokenSign(tokenValue string) *TokenSign { + if tokenValue == "" { + return nil + } + for e := s.TokenSignList.Front(); e != nil; e = e.Next() { + if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Value == tokenValue { + return tokenSign + } + } + return nil +} + +// AddTokenSign add TokenSign +func (s *Session) AddTokenSign(sign *TokenSign) { + if s.GetTokenSign(sign.Value) != nil { + return + } + s.TokenSignList.PushBack(sign) +} + +// RemoveTokenSign remove TokenSign by TokenSign.Value +func (s *Session) RemoveTokenSign(tokenValue string) bool { + if tokenValue == "" { + return false + } + for e := s.TokenSignList.Front(); e != nil; e = e.Next() { + if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Value == tokenValue { + s.TokenSignList.Remove(e) + } + } + return true +} + +// GetLastTokenByDevice get TokenSign.Value by device +func (s *Session) GetLastTokenByDevice(device string) string { + if device == "" { + return "" + } + tokenSignList := s.GetFilterTokenSign(device) + if tokenSign, ok := tokenSignList.Back().Value.(*TokenSign); ok && tokenSign.Device == device { + return tokenSign.Value + } + return "" +}