refactor: persist add Serializer

This commit is contained in:
weloe
2023-08-10 18:33:48 +08:00
parent 5f31fdee04
commit 097d1b9445
7 changed files with 85 additions and 75 deletions

View File

@@ -587,59 +587,14 @@ func (e *Enforcer) AddTokenGenerateFun(tokenStyle string, f model.GenerateFunc)
} }
func (e *Enforcer) GetSession(id string) *model.Session { func (e *Enforcer) GetSession(id string) *model.Session {
if v := e.adapter.Get(e.spliceSessionKey(id)); v != nil { if v := e.adapter.Get(e.spliceSessionKey(id), util.GetType(&model.Session{})); v != nil {
if s := e.sessionUnSerialize(v); s != nil { return v.(*model.Session)
return s
} else {
session, ok := v.(*model.Session)
if !ok {
return nil
}
return session
}
} }
return nil return nil
} }
func (e *Enforcer) sessionUnSerialize(v interface{}) *model.Session {
// get serializer
serializer, ok := e.adapter.(persist.SerializerAdapter)
if !ok {
return nil
}
// to bytes
bytes, err := util.InterfaceToBytes(v)
if err != nil {
return nil
}
session := &model.Session{}
err = serializer.UnSerialize(bytes, session)
if err != nil {
return nil
}
return session
}
func (e *Enforcer) sessionSerialize(v *model.Session) ([]byte, error) {
serializer, ok := e.adapter.(persist.SerializerAdapter)
if !ok {
return nil, nil
}
return serializer.Serialize(v)
}
func (e *Enforcer) SetSession(id string, session *model.Session, timeout int64) error { func (e *Enforcer) SetSession(id string, session *model.Session, timeout int64) error {
bytes, err := e.sessionSerialize(session) err := e.adapter.Set(e.spliceSessionKey(id), session, timeout)
if err != nil {
return err
}
if bytes != nil {
err = e.adapter.Set(e.spliceSessionKey(id), bytes, timeout)
} else {
err = e.adapter.Set(e.spliceSessionKey(id), session, timeout)
}
if err != nil { if err != nil {
return err return err
} }
@@ -655,15 +610,7 @@ func (e *Enforcer) DeleteSession(id string) error {
} }
func (e *Enforcer) UpdateSession(id string, session *model.Session) error { func (e *Enforcer) UpdateSession(id string, session *model.Session) error {
bytes, err := e.sessionSerialize(session) err := e.adapter.Update(e.spliceSessionKey(id), session)
if err != nil {
return err
}
if bytes != nil {
err = e.adapter.Update(e.spliceSessionKey(id), bytes)
} else {
err = e.adapter.Update(e.spliceSessionKey(id), session)
}
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,5 +1,7 @@
package persist package persist
import "reflect"
type Adapter interface { type Adapter interface {
// GetStr string operate string value // GetStr string operate string value
@@ -15,8 +17,9 @@ type Adapter interface {
// UpdateStrTimeout update expire time // UpdateStrTimeout update expire time
UpdateStrTimeout(key string, timeout int64) error UpdateStrTimeout(key string, timeout int64) error
// Get get interface{} // Get returns interface{}
Get(key string) interface{} // If serializer != nil, need to input reflect.Type, used to serializer to deserialize
Get(key string, t ...reflect.Type) interface{}
// Set store interface{} // Set store interface{}
Set(key string, value interface{}, timeout int64) error Set(key string, value interface{}, timeout int64) error
// Update only update interface{} value // Update only update interface{} value
@@ -27,4 +30,8 @@ type Adapter interface {
GetTimeout(key string) int64 GetTimeout(key string) int64
// UpdateTimeout update timeout // UpdateTimeout update timeout
UpdateTimeout(key string, timeout int64) error UpdateTimeout(key string, timeout int64) error
// SetSerializer used to serialize and deserialize
// Serialize when call Set() or Update(), deserialize when call Get(key,t)
SetSerializer(serializer Serializer)
} }

View File

@@ -4,15 +4,19 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/weloe/token-go/constant" "github.com/weloe/token-go/constant"
"github.com/weloe/token-go/util"
"log"
"reflect"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
type DefaultAdapter struct { type DefaultAdapter struct {
dataMap *sync.Map dataMap *sync.Map
expireMap *sync.Map expireMap *sync.Map
once sync.Once once sync.Once
serializer Serializer
} }
var _ Adapter = (*DefaultAdapter)(nil) var _ Adapter = (*DefaultAdapter)(nil)
@@ -24,6 +28,10 @@ func NewDefaultAdapter() *DefaultAdapter {
} }
} }
func (d *DefaultAdapter) SetSerializer(serializer Serializer) {
d.serializer = serializer
}
// GetStr if key is expired delete it before get data // GetStr if key is expired delete it before get data
func (d *DefaultAdapter) GetStr(key string) string { func (d *DefaultAdapter) GetStr(key string) string {
_ = d.getExpireAndDelete(key) _ = d.getExpireAndDelete(key)
@@ -80,17 +88,45 @@ func (d *DefaultAdapter) UpdateStrTimeout(key string, timeout int64) error {
// //
// //
func (d *DefaultAdapter) Get(key string) interface{} { func (d *DefaultAdapter) Get(key string, t ...reflect.Type) interface{} {
d.getExpireAndDelete(key) d.getExpireAndDelete(key)
value, _ := d.dataMap.Load(key) value, _ := d.dataMap.Load(key)
return value
if d.serializer == nil {
return value
}
if t == nil && len(t) == 0 {
return nil
}
bytes, err := util.InterfaceToBytes(value)
if err != nil {
log.Printf("Adapter.Get() failed: %v", err)
return nil
}
instance := reflect.New(t[0].Elem()).Interface()
err = d.serializer.UnSerialize(bytes, instance)
if err != nil {
log.Printf("Adapter.Get() failed: %v", err)
return nil
}
return instance
} }
func (d *DefaultAdapter) Set(key string, value interface{}, timeout int64) error { func (d *DefaultAdapter) Set(key string, value interface{}, timeout int64) error {
if timeout == 0 || timeout <= constant.NotValueExpire { if timeout == 0 || timeout <= constant.NotValueExpire {
return errors.New("args timeout error") return errors.New("args timeout error")
} }
d.dataMap.Store(key, value)
if d.serializer != nil {
bytes, err := d.serializer.Serialize(value)
if err != nil {
return err
}
d.dataMap.Store(key, bytes)
} else {
d.dataMap.Store(key, value)
}
if timeout == constant.NeverExpire { if timeout == constant.NeverExpire {
d.expireMap.Store(key, constant.NeverExpire) d.expireMap.Store(key, constant.NeverExpire)
@@ -105,7 +141,15 @@ func (d *DefaultAdapter) Update(key string, value interface{}) error {
if timeout == constant.NotValueExpire { if timeout == constant.NotValueExpire {
return errors.New("key does not exist") return errors.New("key does not exist")
} }
d.dataMap.Store(key, value) if d.serializer != nil {
bytes, err := d.serializer.Serialize(value)
if err != nil {
return err
}
d.dataMap.Store(key, bytes)
} else {
d.dataMap.Store(key, value)
}
return nil return nil
} }

View File

@@ -1,11 +1,18 @@
package persist package persist
import "reflect"
var _ Adapter = (*EmptyAdapter)(nil) var _ Adapter = (*EmptyAdapter)(nil)
// EmptyAdapter empty adapter for extension to init enforcer // EmptyAdapter empty adapter for extension to init enforcer
type EmptyAdapter struct { type EmptyAdapter struct {
} }
func (e *EmptyAdapter) SetSerializer(serializer Serializer) {
//TODO implement me
panic("implement me")
}
func (e *EmptyAdapter) GetStr(key string) string { func (e *EmptyAdapter) GetStr(key string) string {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")
@@ -36,7 +43,7 @@ func (e *EmptyAdapter) UpdateStrTimeout(key string, timeout int64) error {
panic("implement me") panic("implement me")
} }
func (e *EmptyAdapter) Get(key string) interface{} { func (e *EmptyAdapter) Get(key string, t ...reflect.Type) interface{} {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")
} }
@@ -65,4 +72,3 @@ func (e *EmptyAdapter) UpdateTimeout(key string, timeout int64) error {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")
} }

View File

@@ -1,10 +1,10 @@
package persist package persist
type JsonAdapter struct { type JsonAdapter struct {
*DefaultAdapter
*JsonSerializer
} }
func NewJsonAdapter() *JsonAdapter { func NewJsonAdapter() *DefaultAdapter {
return &JsonAdapter{NewDefaultAdapter(), NewJsonSerializer()} d := NewDefaultAdapter()
d.SetSerializer(NewJsonSerializer())
return d
} }

View File

@@ -1,7 +1,6 @@
package persist package persist
type SerializerAdapter interface { type Serializer interface {
Adapter
Serialize(data interface{}) ([]byte, error) Serialize(data interface{}) ([]byte, error)
UnSerialize([]byte, interface{}) error UnSerialize([]byte, interface{}) error
} }

View File

@@ -1,6 +1,13 @@
package util package util
import "fmt" import (
"fmt"
"reflect"
)
func GetType(i any) reflect.Type {
return reflect.TypeOf(i)
}
func HasNil(arr []interface{}) bool { func HasNil(arr []interface{}) bool {
for _, elem := range arr { for _, elem := range arr {