refactor: update tokenSignList to slice

This commit is contained in:
weloe
2023-05-16 21:57:21 +08:00
parent 29d76d7bfb
commit 900e86123a
4 changed files with 61 additions and 38 deletions

View File

@@ -214,22 +214,24 @@ func (e *Enforcer) LoginByModel(id string, loginModel *model.Login, ctx ctx.Cont
if tokenConfig.MaxLoginCount != -1 { if tokenConfig.MaxLoginCount != -1 {
if session = e.GetSession(id); session != nil { if session = e.GetSession(id); session != nil {
// logout account until loginCount == maxLoginCount if loginCount > maxLoginCount // logout account until loginCount == maxLoginCount if loginCount > maxLoginCount
for element, i := session.TokenSignList.Front(), 0; element != nil && i < session.TokenSignList.Len()-int(tokenConfig.MaxLoginCount); element, i = element.Next(), i+1 { for _, tokenSign := range session.TokenSignList {
tokenSign := element.Value.(*model.TokenSign) if session.TokenSignSize() > int(tokenConfig.MaxLoginCount) {
// delete tokenSign // delete tokenSign
session.RemoveTokenSign(tokenSign.Value) session.RemoveTokenSign(tokenSign.Value)
err = e.updateSession(id, session) err = e.updateSession(id, session)
if err != nil { if err != nil {
return "", err return "", err
} }
// delete token-id // delete token-id
err = e.adapter.Delete(e.spliceTokenKey(tokenSign.Value)) err = e.adapter.Delete(e.spliceTokenKey(tokenSign.Value))
if err != nil { if err != nil {
return "", err return "", err
}
} }
} }
// check TokenSignList length, if length == 0, delete this session // check TokenSignList length, if length == 0, delete this session
if session != nil && session.TokenSignList.Len() == 0 { if session != nil && session.TokenSignSize() == 0 {
err = e.deleteSession(id) err = e.deleteSession(id)
if err != nil { if err != nil {
return "", err return "", err
@@ -306,8 +308,7 @@ func (e *Enforcer) IsLoginById(id string) (bool, error) {
session := e.GetSession(id) session := e.GetSession(id)
if session != nil { if session != nil {
l := session.TokenSignList l := session.TokenSignList
for element := l.Back(); element != nil; element = element.Prev() { for _, tokenSign := range l {
tokenSign := element.Value.(*model.TokenSign)
str := e.adapter.GetStr(e.spliceTokenKey(tokenSign.Value)) str := e.adapter.GetStr(e.spliceTokenKey(tokenSign.Value))
if str == "" { if str == "" {
continue continue
@@ -320,7 +321,6 @@ func (e *Enforcer) IsLoginById(id string) (bool, error) {
if value { if value {
return true, nil return true, nil
} }
} }
} }
@@ -365,7 +365,7 @@ func (e *Enforcer) GetLoginId(ctx ctx.Context) (string, error) {
func (e *Enforcer) GetLoginCount(id string) int { func (e *Enforcer) GetLoginCount(id string) int {
if session := e.GetSession(id); session != nil { if session := e.GetSession(id); session != nil {
return session.TokenSignList.Len() return session.TokenSignSize()
} }
return 0 return 0
} }
@@ -407,7 +407,7 @@ func (e *Enforcer) Kickout(id string, device string) error {
} }
// check TokenSignList length, if length == 0, delete this session // check TokenSignList length, if length == 0, delete this session
if session != nil && session.TokenSignList.Len() == 0 { if session != nil && session.TokenSignSize() == 0 {
err := e.deleteSession(id) err := e.deleteSession(id)
if err != nil { if err != nil {
return err return err

View File

@@ -99,7 +99,7 @@ func (e *Enforcer) logoutByToken(token string) error {
} }
} }
// check TokenSignList length, if length == 0, delete this session // check TokenSignList length, if length == 0, delete this session
if session != nil && session.TokenSignList.Len() == 0 { if session != nil && session.TokenSignSize() == 0 {
err = e.deleteSession(id) err = e.deleteSession(id)
if err != nil { if err != nil {
return err return err

View File

@@ -190,7 +190,7 @@ func TestEnforcer_Login(t *testing.T) {
t.Errorf("Replaced() failed: %v", err) t.Errorf("Replaced() failed: %v", err)
} }
session := enforcer.GetSession("1") session := enforcer.GetSession("1")
t.Logf("id = %v session.tokenSign len = %v", "1", session.TokenSignList.Len()) t.Logf("id = %v session.tokenSign len = %v", "1", session.TokenSignSize())
err = enforcer.CheckLogin(ctx) err = enforcer.CheckLogin(ctx)
if err == nil { if err == nil {
@@ -307,8 +307,8 @@ func TestEnforcerNotConcurrentNotShareLogin(t *testing.T) {
} }
} }
session := enforcer.GetSession("id") session := enforcer.GetSession("id")
if session.TokenSignList.Len() != 1 { if session.TokenSignSize() != 1 {
t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len()) t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignSize())
} }
} }
@@ -327,8 +327,10 @@ func TestEnforcer_ConcurrentShare(t *testing.T) {
} }
} }
session := enforcer.GetSession("id") session := enforcer.GetSession("id")
if session.TokenSignList.Len() != 1 { t.Logf("Login(): session.TokenSignList length = %v", session.TokenSignSize())
t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len())
if session.TokenSignSize() != 1 {
t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignSize())
} }
} }
@@ -346,8 +348,8 @@ func TestEnforcer_ConcurrentNotShareMultiLogin(t *testing.T) {
} }
} }
session := enforcer.GetSession("id") session := enforcer.GetSession("id")
if session.TokenSignList.Len() != 12 { if session.TokenSignSize() != 12 {
t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignList.Len()) t.Errorf("Login() failed: unexpected session.TokenSignList length = %v", session.TokenSignSize())
} }
} }

View File

@@ -2,6 +2,7 @@ package model
import ( import (
"container/list" "container/list"
"encoding/json"
"sync" "sync"
"time" "time"
) )
@@ -19,7 +20,7 @@ type Session struct {
Token string Token string
CreateTime int64 CreateTime int64
DataMap *sync.Map DataMap *sync.Map
TokenSignList *list.List TokenSignList []*TokenSign `json:"TokenSignList"`
} }
func DefaultSession(id string) *Session { func DefaultSession(id string) *Session {
@@ -35,7 +36,7 @@ func NewSession(id string, sessionType string, loginId string) *Session {
Type: sessionType, Type: sessionType,
LoginId: loginId, LoginId: loginId,
CreateTime: time.Now().UnixMilli(), CreateTime: time.Now().UnixMilli(),
TokenSignList: list.New(), TokenSignList: make([]*TokenSign, 0),
} }
} }
@@ -45,8 +46,8 @@ func (s *Session) GetFilterTokenSign(device string) *list.List {
return s.GetTokenSignListCopy() return s.GetTokenSignListCopy()
} }
copyList := list.New() copyList := list.New()
for e := s.TokenSignList.Front(); e != nil; e = e.Next() { for _, tokenSign := range s.TokenSignList {
if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Device == device { if tokenSign.Device == device {
copyList.PushBack(tokenSign) copyList.PushBack(tokenSign)
} }
} }
@@ -56,8 +57,8 @@ func (s *Session) GetFilterTokenSign(device string) *list.List {
// GetTokenSignListCopy find all TokenSign // GetTokenSignListCopy find all TokenSign
func (s *Session) GetTokenSignListCopy() *list.List { func (s *Session) GetTokenSignListCopy() *list.List {
copyList := list.New() copyList := list.New()
for e := s.TokenSignList.Front(); e != nil; e = e.Next() { for _, tokenSign := range s.TokenSignList {
copyList.PushBack(e.Value) copyList.PushBack(tokenSign)
} }
return copyList return copyList
} }
@@ -67,8 +68,8 @@ func (s *Session) GetTokenSign(tokenValue string) *TokenSign {
if tokenValue == "" { if tokenValue == "" {
return nil return nil
} }
for e := s.TokenSignList.Front(); e != nil; e = e.Next() { for _, tokenSign := range s.TokenSignList {
if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Value == tokenValue { if tokenSign.Value == tokenValue {
return tokenSign return tokenSign
} }
} }
@@ -80,7 +81,7 @@ func (s *Session) AddTokenSign(sign *TokenSign) {
if s.GetTokenSign(sign.Value) != nil { if s.GetTokenSign(sign.Value) != nil {
return return
} }
s.TokenSignList.PushBack(sign) s.TokenSignList = append(s.TokenSignList, sign)
} }
// RemoveTokenSign remove TokenSign by TokenSign.Value // RemoveTokenSign remove TokenSign by TokenSign.Value
@@ -88,14 +89,20 @@ func (s *Session) RemoveTokenSign(tokenValue string) bool {
if tokenValue == "" { if tokenValue == "" {
return false return false
} }
for e := s.TokenSignList.Front(); e != nil; e = e.Next() { for i, tokenSign := range s.TokenSignList {
if tokenSign, ok := e.Value.(*TokenSign); ok && tokenSign.Value == tokenValue { if tokenSign.Value == tokenValue {
s.TokenSignList.Remove(e) // delete
s.RemoveTokenSignByIndex(i)
return true
} }
} }
return true return true
} }
func (s *Session) RemoveTokenSignByIndex(i int) {
s.TokenSignList = append(s.TokenSignList[:i], s.TokenSignList[i+1:]...)
}
// GetLastTokenByDevice get TokenSign.Value by device // GetLastTokenByDevice get TokenSign.Value by device
func (s *Session) GetLastTokenByDevice(device string) string { func (s *Session) GetLastTokenByDevice(device string) string {
if device == "" { if device == "" {
@@ -107,3 +114,17 @@ func (s *Session) GetLastTokenByDevice(device string) string {
} }
return "" return ""
} }
// TokenSignSize get tokenSign size
func (s *Session) TokenSignSize() int {
return len(s.TokenSignList)
}
// Json return json string
func (s *Session) Json() string {
b, err := json.Marshal(s)
if err != nil {
return ""
}
return string(b)
}