From 900e86123a287cc9eceb51f2602511c82c98a4d4 Mon Sep 17 00:00:00 2001 From: weloe <1345895607@qq.com> Date: Tue, 16 May 2023 21:57:21 +0800 Subject: [PATCH] refactor: update tokenSignList to slice --- enforcer.go | 36 ++++++++++++++++---------------- enforcer_internal_api.go | 2 +- enforcer_test.go | 16 +++++++------- model/session.go | 45 +++++++++++++++++++++++++++++----------- 4 files changed, 61 insertions(+), 38 deletions(-) diff --git a/enforcer.go b/enforcer.go index 6977297..cddce42 100644 --- a/enforcer.go +++ b/enforcer.go @@ -214,22 +214,24 @@ func (e *Enforcer) LoginByModel(id string, loginModel *model.Login, ctx ctx.Cont if tokenConfig.MaxLoginCount != -1 { if session = e.GetSession(id); session != nil { // logout account until loginCount == maxLoginCount if loginCount > maxLoginCount - for element, i := session.TokenSignList.Front(), 0; element != nil && i < session.TokenSignList.Len()-int(tokenConfig.MaxLoginCount); element, i = element.Next(), i+1 { - tokenSign := element.Value.(*model.TokenSign) - // delete tokenSign - session.RemoveTokenSign(tokenSign.Value) - err = e.updateSession(id, session) - if err != nil { - return "", err - } - // delete token-id - err = e.adapter.Delete(e.spliceTokenKey(tokenSign.Value)) - if err != nil { - return "", err + for _, tokenSign := range session.TokenSignList { + if session.TokenSignSize() > int(tokenConfig.MaxLoginCount) { + // delete tokenSign + session.RemoveTokenSign(tokenSign.Value) + err = e.updateSession(id, session) + if err != nil { + return "", err + } + // delete token-id + err = e.adapter.Delete(e.spliceTokenKey(tokenSign.Value)) + if err != nil { + return "", err + } } } + // check TokenSignList length, if length == 0, delete this session - if session != nil && session.TokenSignList.Len() == 0 { + if session != nil && session.TokenSignSize() == 0 { err = e.deleteSession(id) if err != nil { return "", err @@ -306,8 +308,7 @@ func (e *Enforcer) IsLoginById(id string) (bool, error) { session := e.GetSession(id) if session != nil { l := session.TokenSignList - for element := l.Back(); element != nil; element = element.Prev() { - tokenSign := element.Value.(*model.TokenSign) + for _, tokenSign := range l { str := e.adapter.GetStr(e.spliceTokenKey(tokenSign.Value)) if str == "" { continue @@ -320,7 +321,6 @@ func (e *Enforcer) IsLoginById(id string) (bool, error) { if value { return true, nil } - } } @@ -365,7 +365,7 @@ func (e *Enforcer) GetLoginId(ctx ctx.Context) (string, error) { func (e *Enforcer) GetLoginCount(id string) int { if session := e.GetSession(id); session != nil { - return session.TokenSignList.Len() + return session.TokenSignSize() } return 0 } @@ -407,7 +407,7 @@ func (e *Enforcer) Kickout(id string, device string) error { } // check TokenSignList length, if length == 0, delete this session - if session != nil && session.TokenSignList.Len() == 0 { + if session != nil && session.TokenSignSize() == 0 { err := e.deleteSession(id) if err != nil { return err diff --git a/enforcer_internal_api.go b/enforcer_internal_api.go index fc207fd..b5b4e3a 100644 --- a/enforcer_internal_api.go +++ b/enforcer_internal_api.go @@ -99,7 +99,7 @@ func (e *Enforcer) logoutByToken(token string) error { } } // check TokenSignList length, if length == 0, delete this session - if session != nil && session.TokenSignList.Len() == 0 { + if session != nil && session.TokenSignSize() == 0 { err = e.deleteSession(id) if err != nil { return err diff --git a/enforcer_test.go b/enforcer_test.go index 1996305..a6896a4 100644 --- a/enforcer_test.go +++ b/enforcer_test.go @@ -190,7 +190,7 @@ func TestEnforcer_Login(t *testing.T) { t.Errorf("Replaced() failed: %v", err) } session := enforcer.GetSession("1") - t.Logf("id = %v session.tokenSign len = %v", "1", session.TokenSignList.Len()) + t.Logf("id = %v session.tokenSign len = %v", "1", session.TokenSignSize()) err = enforcer.CheckLogin(ctx) if err == nil { @@ -307,8 +307,8 @@ func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) { } } session := enforcer.GetSession("id") - if session.TokenSignList.Len() != 1 { - t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len()) + if session.TokenSignSize() != 1 { + t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignSize()) } } @@ -327,8 +327,10 @@ func TestEnforcer_ConcurrentShare(t *testing.T) { } } session := enforcer.GetSession("id") - if session.TokenSignList.Len() != 1 { - t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len()) + t.Logf("Login(): session.TokenSignList length = %v", session.TokenSignSize()) + + if session.TokenSignSize() != 1 { + t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignSize()) } } @@ -346,8 +348,8 @@ func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) { } } session := enforcer.GetSession("id") - if session.TokenSignList.Len() != 12 { - t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len()) + if session.TokenSignSize() != 12 { + t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignSize()) } } diff --git a/model/session.go b/model/session.go index e25a74b..2e5ae47 100644 --- a/model/session.go +++ b/model/session.go @@ -2,6 +2,7 @@ package model import ( "container/list" + "encoding/json" "sync" "time" ) @@ -19,7 +20,7 @@ type Session struct { Token string CreateTime int64 DataMap *sync.Map - TokenSignList *list.List + TokenSignList []*TokenSign `json:"TokenSignList"` } func DefaultSession(id string) *Session { @@ -35,7 +36,7 @@ func NewSession(id string, sessionType string, loginId string) *Session { Type: sessionType, LoginId: loginId, CreateTime: time.Now().UnixMilli(), - TokenSignList: list.New(), + TokenSignList: make([]*TokenSign, 0), } } @@ -45,8 +46,8 @@ func (s *Session) GetFilterTokenSign(device string) *list.List { return s.GetTokenSignListCopy() } copyList := list.New() - for e := s.TokenSignList.Front(); e != nil; e = e.Next() { - if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Device == device { + for _, tokenSign := range s.TokenSignList { + if tokenSign.Device == device { copyList.PushBack(tokenSign) } } @@ -56,8 +57,8 @@ func (s *Session) GetFilterTokenSign(device string) *list.List { // GetTokenSignListCopy find all TokenSign func (s *Session) GetTokenSignListCopy() *list.List { copyList := list.New() - for e := s.TokenSignList.Front(); e != nil; e = e.Next() { - copyList.PushBack(e.Value) + for _, tokenSign := range s.TokenSignList { + copyList.PushBack(tokenSign) } return copyList } @@ -67,8 +68,8 @@ func (s *Session) GetTokenSign(tokenValue string) *TokenSign { if tokenValue == "" { return nil } - for e := s.TokenSignList.Front(); e != nil; e = e.Next() { - if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Value == tokenValue { + for _, tokenSign := range s.TokenSignList { + if tokenSign.Value == tokenValue { return tokenSign } } @@ -80,7 +81,7 @@ func (s *Session) AddTokenSign(sign *TokenSign) { if s.GetTokenSign(sign.Value) != nil { return } - s.TokenSignList.PushBack(sign) + s.TokenSignList = append(s.TokenSignList, sign) } // RemoveTokenSign remove TokenSign by TokenSign.Value @@ -88,14 +89,20 @@ func (s *Session) RemoveTokenSign(tokenValue string) bool { if tokenValue == "" { return false } - for e := s.TokenSignList.Front(); e != nil; e = e.Next() { - if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Value == tokenValue { - s.TokenSignList.Remove(e) + for i, tokenSign := range s.TokenSignList { + if tokenSign.Value == tokenValue { + // delete + s.RemoveTokenSignByIndex(i) + return true } } return true } +func (s *Session) RemoveTokenSignByIndex(i int) { + s.TokenSignList = append(s.TokenSignList[:i], s.TokenSignList[i+1:]...) +} + // GetLastTokenByDevice get TokenSign.Value by device func (s *Session) GetLastTokenByDevice(device string) string { if device == "" { @@ -107,3 +114,17 @@ func (s *Session) GetLastTokenByDevice(device string) string { } return "" } + +// TokenSignSize get tokenSign size +func (s *Session) TokenSignSize() int { + return len(s.TokenSignList) +} + +// Json return json string +func (s *Session) Json() string { + b, err := json.Marshal(s) + if err != nil { + return "" + } + return string(b) +}