diff --git a/enforcer.go b/enforcer.go index 421ccbb..d78acaf 100644 --- a/enforcer.go +++ b/enforcer.go @@ -22,7 +22,21 @@ type Enforcer struct { adapter persist.Adapter watcher persist.Watcher logger log.Logger - authManager interface{} + + dispatcher persist.Dispatcher + notifyDispatcher bool + + updatableWatcher persist.UpdatableWatcher + notifyUpdatableWatcher bool + + authManager interface{} +} + +func (e *Enforcer) EnableUpdatableWatcher(b bool) { + if e.updatableWatcher == nil { + return + } + e.notifyUpdatableWatcher = b } func NewDefaultAdapter() persist.Adapter { @@ -482,7 +496,7 @@ func (e *Enforcer) GetSession(id string) *model.Session { } func (e *Enforcer) SetSession(id string, session *model.Session, timeout int64) error { - err := e.adapter.Set(e.spliceSessionKey(id), session, timeout) + err := e.notifySet(e.spliceSessionKey(id), session, timeout) if err != nil { return err } @@ -490,7 +504,7 @@ func (e *Enforcer) SetSession(id string, session *model.Session, timeout int64) } func (e *Enforcer) DeleteSession(id string) error { - err := e.adapter.Delete(e.spliceSessionKey(id)) + err := e.notifyDelete(e.spliceSessionKey(id)) if err != nil { return err } @@ -498,7 +512,7 @@ func (e *Enforcer) DeleteSession(id string) error { } func (e *Enforcer) UpdateSession(id string, session *model.Session) error { - err := e.adapter.Update(e.spliceSessionKey(id), session) + err := e.notifyUpdate(e.spliceSessionKey(id), session) if err != nil { return err } diff --git a/enforcer_distributed.go b/enforcer_distributed.go new file mode 100644 index 0000000..b5c481a --- /dev/null +++ b/enforcer_distributed.go @@ -0,0 +1,46 @@ +package token_go + +import "github.com/weloe/token-go/persist" + +type DistributedEnforcer struct { + *Enforcer +} + +func NewDistributedEnforcer(enforcer *Enforcer) *DistributedEnforcer { + return &DistributedEnforcer{enforcer} +} + +func (e *DistributedEnforcer) SetStrSelf(key string, value string, timeout int64) error { + return e.adapter.SetStr(key, value, timeout) +} + +func (e *DistributedEnforcer) UpdateStrSelf(key string, value string) error { + return e.adapter.UpdateStr(key, value) +} + +func (e *DistributedEnforcer) SetSelf(key string, value interface{}, timeout int64) error { + return e.adapter.Set(key, value, timeout) +} + +func (e *DistributedEnforcer) UpdateSelf(key string, value interface{}) error { + return e.adapter.Update(key, value) +} + +func (e *DistributedEnforcer) DeleteSelf(key string) error { + return e.adapter.DeleteStr(key) +} + +func (e *DistributedEnforcer) UpdateTimeoutSelf(key string, timeout int64) error { + return e.adapter.UpdateTimeout(key, timeout) +} + +func (e *DistributedEnforcer) EnableDispatcher(b bool) { + if e.dispatcher == nil { + return + } + e.notifyDispatcher = b +} + +func (e *Enforcer) SetDispatcher(dispatcher persist.Dispatcher) { + e.dispatcher = dispatcher +} diff --git a/enforcer_interface.go b/enforcer_interface.go index 9320778..065d787 100644 --- a/enforcer_interface.go +++ b/enforcer_interface.go @@ -91,3 +91,21 @@ type IEnforcer interface { UpdateSession(id string, session *model.Session) error SetSession(id string, session *model.Session, timeout int64) error } + +var _ IDistributedEnforcer = &DistributedEnforcer{} + +type IDistributedEnforcer interface { + IEnforcer + // SetStrSelf store string in all instances + SetStrSelf(key string, value string, timeout int64) error + // UpdateStrSelf only update string value in all instances + UpdateStrSelf(key string, value string) error + // SetSelf store interface{} in all instances + SetSelf(key string, value interface{}, timeout int64) error + // UpdateSelf only update interface{} value in all instances + UpdateSelf(key string, value interface{}) error + // DeleteSelf delete interface{} value in all instances + DeleteSelf(key string) error + // UpdateTimeoutSelf update timeout in all instances + UpdateTimeoutSelf(key string, timeout int64) error +} diff --git a/enforcer_internal_api.go b/enforcer_internal_api.go index 2df2037..ff68bc8 100644 --- a/enforcer_internal_api.go +++ b/enforcer_internal_api.go @@ -114,7 +114,7 @@ func (e *Enforcer) checkId(str string) (bool, error) { } func (e *Enforcer) SetIdByToken(id string, tokenValue string, timeout int64) error { - err := e.adapter.SetStr(e.spliceTokenKey(tokenValue), id, timeout) + err := e.notifySetStr(e.spliceTokenKey(tokenValue), id, timeout) return err } @@ -123,22 +123,22 @@ func (e *Enforcer) getIdByToken(token string) string { } func (e *Enforcer) deleteIdByToken(tokenValue string) error { - err := e.adapter.DeleteStr(e.spliceTokenKey(tokenValue)) + err := e.notifyDelete(e.spliceTokenKey(tokenValue)) return err } func (e *Enforcer) updateIdByToken(tokenValue string, id string) error { - err := e.adapter.UpdateStr(e.spliceTokenKey(tokenValue), id) + err := e.notifyUpdateStr(e.spliceTokenKey(tokenValue), id) return err } func (e *Enforcer) setBanned(id string, service string, level int, time int64) error { - err := e.adapter.SetStr(e.spliceBannedKey(id, service), strconv.Itoa(level), time) + err := e.notifySetStr(e.spliceBannedKey(id, service), strconv.Itoa(level), time) return err } func (e *Enforcer) deleteBanned(id string, service string) error { - err := e.adapter.DeleteStr(e.spliceBannedKey(id, service)) + err := e.notifyDelete(e.spliceBannedKey(id, service)) return err } @@ -153,7 +153,7 @@ func (e *Enforcer) getBannedTime(id string, service string) int64 { } func (e *Enforcer) setSecSafe(token string, service string, time int64) error { - err := e.adapter.SetStr(e.spliceSecSafeKey(token, service), constant.DefaultSecondAuthValue, time) + err := e.notifySetStr(e.spliceSecSafeKey(token, service), constant.DefaultSecondAuthValue, time) return err } @@ -168,12 +168,12 @@ func (e *Enforcer) getSecSafe(token string, service string) string { } func (e *Enforcer) deleteSecSafe(token string, service string) error { - err := e.adapter.DeleteStr(e.spliceSecSafeKey(token, service)) + err := e.notifyDelete(e.spliceSecSafeKey(token, service)) return err } func (e *Enforcer) setTempToken(service string, token string, value string, timeout int64) error { - err := e.adapter.SetStr(e.spliceTempTokenKey(service, token), value, timeout) + err := e.notifySetStr(e.spliceTempTokenKey(service, token), value, timeout) return err } @@ -182,11 +182,11 @@ func (e *Enforcer) getTimeoutByTempToken(service string, token string) int64 { } func (e *Enforcer) deleteByTempToken(service string, tempToken string) error { - return e.adapter.DeleteStr(e.spliceTempTokenKey(service, tempToken)) + return e.notifyDelete(e.spliceTempTokenKey(service, tempToken)) } func (e *Enforcer) createQRCode(id string, timeout int64) error { - return e.adapter.Set(e.spliceQRCodeKey(id), model.NewQRCode(id), timeout) + return e.notifySet(e.spliceQRCodeKey(id), model.NewQRCode(id), timeout) } func (e *Enforcer) getQRCode(id string) *model.QRCode { @@ -213,11 +213,11 @@ func (e *Enforcer) getQRCodeTimeout(id string) int64 { } func (e *Enforcer) updateQRCode(id string, qrCode *model.QRCode) error { - return e.adapter.Update(e.spliceQRCodeKey(id), qrCode) + return e.notifyUpdate(e.spliceQRCodeKey(id), qrCode) } func (e *Enforcer) deleteQRCode(id string) error { - return e.adapter.Delete(e.spliceQRCodeKey(id)) + return e.notifyDelete(e.spliceQRCodeKey(id)) } func (e *Enforcer) getByTempToken(service string, tempToken string) string { @@ -253,3 +253,107 @@ func (e *Enforcer) spliceQRCodeKey(QRCodeId string) string { func (e *Enforcer) SetJwtSecretKey(key string) { e.config.JwtSecretKey = key } + +func (e *Enforcer) notifySetStr(key string, value string, timeout int64) error { + if e.shouldNotifyDispatcher() { + return e.dispatcher.SetAllStr(key, value, timeout) + } + err := e.adapter.SetStr(key, value, timeout) + if err != nil { + return err + } + if e.shouldNotifyUpdatableWatcher() { + return e.updatableWatcher.UpdateForSetStr(key, value, timeout) + } + return nil +} + +func (e *Enforcer) notifyUpdateStr(key string, value string) error { + if e.shouldNotifyDispatcher() { + return e.dispatcher.UpdateAllStr(key, value) + } + err := e.adapter.UpdateStr(key, value) + if err != nil { + return err + } + if e.shouldNotifyUpdatableWatcher() { + return e.updatableWatcher.UpdateForUpdateStr(key, value) + } + return nil +} + +func (e *Enforcer) notifySet(key string, value interface{}, timeout int64) error { + if e.shouldNotifyDispatcher() { + return e.dispatcher.SetAll(key, value, timeout) + } + err := e.adapter.Set(key, value, timeout) + if err != nil { + return err + } + if e.shouldNotifyUpdatableWatcher() { + return e.updatableWatcher.UpdateForSet(key, value, timeout) + } + return nil +} + +func (e *Enforcer) notifyUpdate(key string, value interface{}) error { + if e.shouldNotifyDispatcher() { + return e.dispatcher.UpdateAll(key, value) + } + err := e.adapter.Update(key, value) + if err != nil { + return err + } + if e.shouldNotifyUpdatableWatcher() { + return e.updatableWatcher.UpdateForUpdate(key, value) + } + return nil +} + +func (e *Enforcer) notifyDelete(key string) error { + if e.shouldNotifyDispatcher() { + return e.dispatcher.DeleteAll(key) + } + err := e.adapter.Delete(key) + if err != nil { + return err + } + if e.shouldNotifyUpdatableWatcher() { + return e.updatableWatcher.UpdateForDelete(key) + } + return nil +} + +// nolint:golint,unused +func (e *Enforcer) notifyUpdateTimeout(key string, timeout int64) error { + if e.shouldNotifyDispatcher() { + return e.dispatcher.UpdateAllTimeout(key, timeout) + } + err := e.adapter.UpdateTimeout(key, timeout) + if err != nil { + return err + } + if e.shouldNotifyUpdatableWatcher() { + return e.updatableWatcher.UpdateForUpdateTimeout(key, timeout) + } + return nil +} + +func (e *Enforcer) shouldNotifyDispatcher() bool { + if e.dispatcher != nil && e.notifyDispatcher { + return true + } + return false +} + +func (e *Enforcer) shouldNotifyUpdatableWatcher() bool { + if e.updatableWatcher != nil && e.notifyUpdatableWatcher { + return true + } + return false +} + +// nolint:golint,unused +func (e *Enforcer) shouldPersist() bool { + return e.adapter != nil +} diff --git a/persist/dispatcher.go b/persist/dispatcher.go new file mode 100644 index 0000000..292f77f --- /dev/null +++ b/persist/dispatcher.go @@ -0,0 +1,19 @@ +package persist + +type Dispatcher interface { + + // SetAllStr store string in all instances + SetAllStr(key string, value string, timeout int64) error + // UpdateAllStr only update string value in all instances + UpdateAllStr(key string, value string) error + + // SetAll store interface{} in all instances + SetAll(key string, value interface{}, timeout int64) error + // UpdateAll only update interface{} value in all instances + UpdateAll(key string, value interface{}) error + + // DeleteAll delete interface{} value in all instances + DeleteAll(key string) error + // UpdateAllTimeout update timeout in all instances + UpdateAllTimeout(key string, timeout int64) error +} diff --git a/persist/watcher_update.go b/persist/watcher_update.go new file mode 100644 index 0000000..0d223b7 --- /dev/null +++ b/persist/watcher_update.go @@ -0,0 +1,11 @@ +package persist + +// UpdatableWatcher called when data updated +type UpdatableWatcher interface { + UpdateForSetStr(key string, value interface{}, timeout int64) error + UpdateForUpdateStr(key string, value interface{}) error + UpdateForSet(key string, value interface{}, timeout int64) error + UpdateForUpdate(key string, value interface{}) error + UpdateForDelete(key string) error + UpdateForUpdateTimeout(key string, timeout int64) error +}