diff --git a/enforcer.go b/enforcer.go index 3b83060..590346e 100644 --- a/enforcer.go +++ b/enforcer.go @@ -587,59 +587,14 @@ func (e *Enforcer) AddTokenGenerateFun(tokenStyle string, f model.GenerateFunc) } func (e *Enforcer) GetSession(id string) *model.Session { - if v := e.adapter.Get(e.spliceSessionKey(id)); v != nil { - if s := e.sessionUnSerialize(v); s != nil { - return s - } else { - session, ok := v.(*model.Session) - if !ok { - return nil - } - return session - } + if v := e.adapter.Get(e.spliceSessionKey(id), util.GetType(&model.Session{})); v != nil { + return v.(*model.Session) } return nil } -func (e *Enforcer) sessionUnSerialize(v interface{}) *model.Session { - // get serializer - serializer, ok := e.adapter.(persist.SerializerAdapter) - if !ok { - return nil - } - - // to bytes - bytes, err := util.InterfaceToBytes(v) - if err != nil { - return nil - } - session := &model.Session{} - err = serializer.UnSerialize(bytes, session) - if err != nil { - return nil - } - return session -} - -func (e *Enforcer) sessionSerialize(v *model.Session) ([]byte, error) { - serializer, ok := e.adapter.(persist.SerializerAdapter) - - if !ok { - return nil, nil - } - return serializer.Serialize(v) -} - func (e *Enforcer) SetSession(id string, session *model.Session, timeout int64) error { - bytes, err := e.sessionSerialize(session) - if err != nil { - return err - } - if bytes != nil { - err = e.adapter.Set(e.spliceSessionKey(id), bytes, timeout) - } else { - err = e.adapter.Set(e.spliceSessionKey(id), session, timeout) - } + err := e.adapter.Set(e.spliceSessionKey(id), session, timeout) if err != nil { return err } @@ -655,15 +610,7 @@ func (e *Enforcer) DeleteSession(id string) error { } func (e *Enforcer) UpdateSession(id string, session *model.Session) error { - bytes, err := e.sessionSerialize(session) - if err != nil { - return err - } - if bytes != nil { - err = e.adapter.Update(e.spliceSessionKey(id), bytes) - } else { - err = e.adapter.Update(e.spliceSessionKey(id), session) - } + err := e.adapter.Update(e.spliceSessionKey(id), session) if err != nil { return err } diff --git a/persist/adapter.go b/persist/adapter.go index 6088bc1..5e52aa2 100644 --- a/persist/adapter.go +++ b/persist/adapter.go @@ -1,5 +1,7 @@ package persist +import "reflect" + type Adapter interface { // GetStr string operate string value @@ -15,8 +17,9 @@ type Adapter interface { // UpdateStrTimeout update expire time UpdateStrTimeout(key string, timeout int64) error - // Get get interface{} - Get(key string) interface{} + // Get returns interface{} + // If serializer != nil, need to input reflect.Type, used to serializer to deserialize + Get(key string, t ...reflect.Type) interface{} // Set store interface{} Set(key string, value interface{}, timeout int64) error // Update only update interface{} value @@ -27,4 +30,8 @@ type Adapter interface { GetTimeout(key string) int64 // UpdateTimeout update timeout UpdateTimeout(key string, timeout int64) error + + // SetSerializer used to serialize and deserialize + // Serialize when call Set() or Update(), deserialize when call Get(key,t) + SetSerializer(serializer Serializer) } diff --git a/persist/default_adapter.go b/persist/default_adapter.go index bd9d570..0410255 100644 --- a/persist/default_adapter.go +++ b/persist/default_adapter.go @@ -4,15 +4,19 @@ import ( "errors" "fmt" "github.com/weloe/token-go/constant" + "github.com/weloe/token-go/util" + "log" + "reflect" "strings" "sync" "time" ) type DefaultAdapter struct { - dataMap *sync.Map - expireMap *sync.Map - once sync.Once + dataMap *sync.Map + expireMap *sync.Map + once sync.Once + serializer Serializer } var _ Adapter = (*DefaultAdapter)(nil) @@ -24,6 +28,10 @@ func NewDefaultAdapter() *DefaultAdapter { } } +func (d *DefaultAdapter) SetSerializer(serializer Serializer) { + d.serializer = serializer +} + // GetStr if key is expired delete it before get data func (d *DefaultAdapter) GetStr(key string) string { _ = d.getExpireAndDelete(key) @@ -80,17 +88,45 @@ func (d *DefaultAdapter) UpdateStrTimeout(key string, timeout int64) error { // // -func (d *DefaultAdapter) Get(key string) interface{} { +func (d *DefaultAdapter) Get(key string, t ...reflect.Type) interface{} { d.getExpireAndDelete(key) value, _ := d.dataMap.Load(key) - return value + + if d.serializer == nil { + return value + } + if t == nil && len(t) == 0 { + return nil + } + bytes, err := util.InterfaceToBytes(value) + if err != nil { + log.Printf("Adapter.Get() failed: %v", err) + return nil + } + instance := reflect.New(t[0].Elem()).Interface() + err = d.serializer.UnSerialize(bytes, instance) + if err != nil { + log.Printf("Adapter.Get() failed: %v", err) + return nil + } + + return instance } func (d *DefaultAdapter) Set(key string, value interface{}, timeout int64) error { if timeout == 0 || timeout <= constant.NotValueExpire { return errors.New("args timeout error") } - d.dataMap.Store(key, value) + + if d.serializer != nil { + bytes, err := d.serializer.Serialize(value) + if err != nil { + return err + } + d.dataMap.Store(key, bytes) + } else { + d.dataMap.Store(key, value) + } if timeout == constant.NeverExpire { d.expireMap.Store(key, constant.NeverExpire) @@ -105,7 +141,15 @@ func (d *DefaultAdapter) Update(key string, value interface{}) error { if timeout == constant.NotValueExpire { return errors.New("key does not exist") } - d.dataMap.Store(key, value) + if d.serializer != nil { + bytes, err := d.serializer.Serialize(value) + if err != nil { + return err + } + d.dataMap.Store(key, bytes) + } else { + d.dataMap.Store(key, value) + } return nil } diff --git a/persist/empty_adapter.go b/persist/empty_adapter.go index fb92066..0128075 100644 --- a/persist/empty_adapter.go +++ b/persist/empty_adapter.go @@ -1,11 +1,18 @@ package persist +import "reflect" + var _ Adapter = (*EmptyAdapter)(nil) // EmptyAdapter empty adapter for extension to init enforcer type EmptyAdapter struct { } +func (e *EmptyAdapter) SetSerializer(serializer Serializer) { + //TODO implement me + panic("implement me") +} + func (e *EmptyAdapter) GetStr(key string) string { //TODO implement me panic("implement me") @@ -36,7 +43,7 @@ func (e *EmptyAdapter) UpdateStrTimeout(key string, timeout int64) error { panic("implement me") } -func (e *EmptyAdapter) Get(key string) interface{} { +func (e *EmptyAdapter) Get(key string, t ...reflect.Type) interface{} { //TODO implement me panic("implement me") } @@ -65,4 +72,3 @@ func (e *EmptyAdapter) UpdateTimeout(key string, timeout int64) error { //TODO implement me panic("implement me") } - diff --git a/persist/json_adapter.go b/persist/json_adapter.go index d1ad98c..7f8d003 100644 --- a/persist/json_adapter.go +++ b/persist/json_adapter.go @@ -1,10 +1,10 @@ package persist type JsonAdapter struct { - *DefaultAdapter - *JsonSerializer } -func NewJsonAdapter() *JsonAdapter { - return &JsonAdapter{NewDefaultAdapter(), NewJsonSerializer()} +func NewJsonAdapter() *DefaultAdapter { + d := NewDefaultAdapter() + d.SetSerializer(NewJsonSerializer()) + return d } diff --git a/persist/serializer_adapter.go b/persist/serializer.go similarity index 70% rename from persist/serializer_adapter.go rename to persist/serializer.go index 9a0c632..cee0381 100644 --- a/persist/serializer_adapter.go +++ b/persist/serializer.go @@ -1,7 +1,6 @@ package persist -type SerializerAdapter interface { - Adapter +type Serializer interface { Serialize(data interface{}) ([]byte, error) UnSerialize([]byte, interface{}) error } diff --git a/util/util.go b/util/util.go index eef97b4..e6e3548 100644 --- a/util/util.go +++ b/util/util.go @@ -1,6 +1,13 @@ package util -import "fmt" +import ( + "fmt" + "reflect" +) + +func GetType(i any) reflect.Type { + return reflect.TypeOf(i) +} func HasNil(arr []interface{}) bool { for _, elem := range arr {