diff --git a/config/token.go b/config/token.go index b7fe156..357ce72 100644 --- a/config/token.go +++ b/config/token.go @@ -1,12 +1,15 @@ package config +import "github.com/weloe/token-go/constant" + type TokenConfig struct { // TokenName prefix TokenStyle string TokenPrefix string TokenName string - Timeout int64 + Timeout int64 + // If last operate time < ActivityTimeout, token expired ActivityTimeout int64 // Data clean period DataRefreshPeriod int64 @@ -17,6 +20,7 @@ type TokenConfig struct { IsConcurrent bool IsShare bool // If (IsConcurrent == true && IsShare == false), support MaxLoginCount + // If IsConcurrent == -1, do not need to check loginCount MaxLoginCount int16 // Read token method @@ -46,7 +50,7 @@ func DefaultTokenConfig() *TokenConfig { return &TokenConfig{ TokenStyle: "uuid", TokenPrefix: "", - TokenName: "tokenGo", + TokenName: constant.TokenName, Timeout: 60 * 60 * 24 * 30, ActivityTimeout: -1, DataRefreshPeriod: 30, diff --git a/constant/constant.go b/constant/constant.go index 50d3b9d..c1bd0c9 100644 --- a/constant/constant.go +++ b/constant/constant.go @@ -7,3 +7,18 @@ const ( SetCookie = "Set-Cookie" ) +// persist timeout constant variable +const ( + // NeverExpire does not expire + NeverExpire int64 = -1 + // NotValueExpire does not exist + NotValueExpire int64 = -2 +) + +const ( + TokenName = "tokenGo" +) + +const ( + BeReplaced int = -4 +) diff --git a/ctx/go-http-context/context.go b/ctx/go-http-context/context.go index 8e1b395..4fe045d 100644 --- a/ctx/go-http-context/context.go +++ b/ctx/go-http-context/context.go @@ -2,6 +2,7 @@ package go_http_context import ( "github.com/weloe/token-go/ctx" + "net/http" "reflect" ) @@ -13,6 +14,14 @@ type HttpContext struct { reqStorage ctx.ReqStorage } +func NewHttpContext(req *http.Request, writer http.ResponseWriter) *HttpContext { + return &HttpContext{ + req: NewHttpRequest(req), + response: NewResponse(req, writer), + reqStorage: NewReqStorage(req), + } +} + func (h *HttpContext) IsValidContext() bool { return h.req != nil && !reflect.DeepEqual(h.req, &HttpRequest{}) } diff --git a/ctx/go-http-context/context_test.go b/ctx/go-http-context/context_test.go index 238e664..989180c 100644 --- a/ctx/go-http-context/context_test.go +++ b/ctx/go-http-context/context_test.go @@ -48,7 +48,11 @@ func NewTestHttpRequest(t *testing.T) *HttpRequest { func NewTestHttpReqStore(t *testing.T) *HttpReqStorage { request := NewTestRequest(t) - httpReqStorage := NewReqStorage(request) + httpReqStorage := NewReqStorage(nil) + if httpReqStorage != nil { + t.Errorf("NewReqStorage() failed: value = %v", httpReqStorage) + } + httpReqStorage = NewReqStorage(request) return httpReqStorage } @@ -441,3 +445,16 @@ func TestDeleteCookieHandler(t *testing.T) { func containsString(s string, substr string) bool { return len(s) >= len(substr) && s[:len(substr)] == substr } + +func TestNewHttpContext(t *testing.T) { + context := NewHttpContext(nil, nil) + if context == nil { + t.Errorf("NewHttpContext() failed: value = %v", context) + } + request := context.Request() + response := context.Response() + storage := context.ReqStorage() + if request == nil || response == nil || storage == nil { + t.Errorf("HttpContext failed ") + } +} diff --git a/ctx/go-http-context/request-storage.go b/ctx/go-http-context/request-storage.go index 3fa55fd..8add0c8 100644 --- a/ctx/go-http-context/request-storage.go +++ b/ctx/go-http-context/request-storage.go @@ -11,6 +11,9 @@ type HttpReqStorage struct { } func NewReqStorage(req *http.Request) *HttpReqStorage { + if req == nil { + return nil + } return &HttpReqStorage{source: req.Context()} } diff --git a/persist/adapter.go b/persist/adapter.go new file mode 100644 index 0000000..6e33453 --- /dev/null +++ b/persist/adapter.go @@ -0,0 +1,33 @@ +package persist + +type Adapter interface { + + // GetStr string operate string value + GetStr(key string) string + // SetStr set store value and timeout + SetStr(key string, value string, timeout int64) error + // UpdateStr only update value + UpdateStr(key string, value string) error + // DeleteStr delete string value + DeleteStr(key string) error + // GetStrTimeout get expire + GetStrTimeout(key string) int64 + // UpdateStrTimeout update expire time + UpdateStrTimeout(key string, timeout int64) error + + // Get get interface{} + Get(key string) interface{} + // Set store interface{} + Set(key string, value interface{}, timeout int64) error + // Update only update interface{} value + Update(key string, value interface{}) error + // Delete delete interface{} value + Delete(key string) error + // GetTimeout get expire + GetTimeout(key string) int64 + // UpdateTimeout update timeout + UpdateTimeout(key string, timeout int64) error + + // DeleteBatchFilteredKey delete data by keyPrefix + DeleteBatchFilteredKey(filterKeyPrefix string) error +} diff --git a/persist/default_adapter.go b/persist/default_adapter.go new file mode 100644 index 0000000..d409add --- /dev/null +++ b/persist/default_adapter.go @@ -0,0 +1,170 @@ +package persist + +import ( + "errors" + "fmt" + "github.com/weloe/token-go/constant" + "strings" + "sync" + "time" +) + +type DefaultAdapter struct { + dataMap *sync.Map + expireMap *sync.Map +} + +var _ Adapter = (*DefaultAdapter)(nil) + +func NewDefaultAdapter() *DefaultAdapter { + return &DefaultAdapter{ + dataMap: &sync.Map{}, + expireMap: &sync.Map{}, + } +} + +// GetStr if key is expired delete it before get data +func (d *DefaultAdapter) GetStr(key string) string { + _ = d.getExpireAndDelete(key) + value, _ := d.dataMap.Load(key) + if value == nil { + return "" + } + return fmt.Sprintf("%v", value) +} + +func (d *DefaultAdapter) SetStr(key string, value string, timeout int64) error { + if timeout == 0 || timeout <= constant.NotValueExpire { + return errors.New("args timeout error") + } + d.dataMap.Store(key, value) + + if timeout == constant.NeverExpire { + d.expireMap.Store(key, constant.NeverExpire) + } else { + d.expireMap.Store(key, time.Now().UnixMilli()+timeout*1000) + } + return nil +} + +func (d *DefaultAdapter) UpdateStr(key string, value string) error { + timeout := d.GetStrTimeout(key) + if timeout == constant.NotValueExpire { + return errors.New("does not exist") + } + d.dataMap.Store(key, value) + return nil +} + +func (d *DefaultAdapter) DeleteStr(key string) error { + d.dataMap.Delete(key) + d.expireMap.Delete(key) + return nil +} + +func (d *DefaultAdapter) GetStrTimeout(key string) int64 { + return d.getTimeout(key) +} + +func (d *DefaultAdapter) UpdateStrTimeout(key string, timeout int64) error { + if timeout == constant.NeverExpire { + d.expireMap.Store(key, constant.NeverExpire) + } else { + d.expireMap.Store(key, time.Now().UnixMilli()+timeout*1000) + } + return nil +} + +// interface{} operation +// +// + +func (d *DefaultAdapter) Get(key string) interface{} { + d.getExpireAndDelete(key) + value, _ := d.dataMap.Load(key) + return value +} + +func (d *DefaultAdapter) Set(key string, value interface{}, timeout int64) error { + if timeout == 0 || timeout <= constant.NotValueExpire { + return errors.New("args timeout error") + } + d.dataMap.Store(key, value) + + if timeout == constant.NeverExpire { + d.expireMap.Store(key, constant.NeverExpire) + } else { + d.expireMap.Store(key, time.Now().UnixMilli()+timeout*1000) + } + return nil +} + +func (d *DefaultAdapter) Update(key string, value interface{}) error { + timeout := d.GetStrTimeout(key) + if timeout == constant.NotValueExpire { + return errors.New("key does not exist") + } + d.dataMap.Store(key, value) + return nil +} + +func (d *DefaultAdapter) GetTimeout(key string) int64 { + return d.getTimeout(key) +} + +func (d *DefaultAdapter) UpdateTimeout(key string, timeout int64) error { + if timeout == constant.NeverExpire { + d.expireMap.Store(key, constant.NeverExpire) + } else { + d.expireMap.Store(key, time.Now().UnixMilli()+timeout*1000) + } + return nil +} + +func (d *DefaultAdapter) Delete(key string) error { + d.dataMap.Delete(key) + d.expireMap.Delete(key) + return nil +} + +func (d *DefaultAdapter) DeleteBatchFilteredKey(keyPrefix string) error { + d.dataMap.Range(func(key, value any) bool { + if strings.HasPrefix(key.(string), keyPrefix) { + d.dataMap.Delete(key) + } + return true + }) + return nil +} + +// delete key when getValue is expired +func (d *DefaultAdapter) getExpireAndDelete(key string) int64 { + expirationTime, _ := d.expireMap.Load(key) + + if expirationTime == nil { + return 0 + } + + if expirationTime.(int64) != constant.NeverExpire && expirationTime.(int64) <= time.Now().UnixMilli() { + d.dataMap.Delete(key) + d.expireMap.Delete(key) + } + return expirationTime.(int64) +} + +func (d *DefaultAdapter) getTimeout(key string) int64 { + expirationTime := d.getExpireAndDelete(key) + if expirationTime == 0 { + return constant.NotValueExpire + } + if expirationTime == constant.NeverExpire { + return constant.NeverExpire + } + timeout := (expirationTime - time.Now().UnixMilli()) / 1000 + if timeout <= 0 { + d.dataMap.Delete(key) + d.expireMap.Delete(key) + return constant.NotValueExpire + } + return timeout +} diff --git a/persist/default_adapter_test.go b/persist/default_adapter_test.go new file mode 100644 index 0000000..b0f1558 --- /dev/null +++ b/persist/default_adapter_test.go @@ -0,0 +1,201 @@ +package persist + +import ( + "testing" + "time" +) + +func NewTestDefaultAdapter() Adapter { + return NewDefaultAdapter() +} + +func TestDefaultAdapter_StrOperation(t *testing.T) { + defaultAdapter := NewTestDefaultAdapter() + + if err := defaultAdapter.SetStr("k1", "v1", 0); err == nil { + t.Errorf("SetStr() failed: set timeout = 0") + } + + if err := defaultAdapter.SetStr("k1", "v1", -2); err == nil { + t.Errorf("SetStr() failed: set timeout = -2") + } + + if err := defaultAdapter.SetStr("k2", "v2", -1); err != nil { + t.Errorf("SetStr() failed: can't set data") + } + if v := defaultAdapter.GetStr("k2"); v != "v2" { + t.Errorf("GetStr() failed: value is %s, want 'v2' ", v) + } + if v := defaultAdapter.GetStrTimeout("k2"); v != -1 { + t.Errorf("GetStrTimeout() failed: timeout is %v,want -1 ", v) + } + if v := defaultAdapter.GetStrTimeout("k3"); v != -2 { + t.Errorf("GetStrTimeout() failed: timeout is %v,want -2 ", v) + } + + if err := defaultAdapter.SetStr("k1", "v1", 1); err != nil { + t.Errorf("SetStr() failed: can't set data") + } + time.Sleep(1 * time.Second) + if v := defaultAdapter.Get("k1"); v != nil { + t.Errorf("getExpireAndDelete() faliled: get expired value") + } + + err1 := defaultAdapter.SetStr("k", "v", 9) + if err1 != nil { + t.Errorf("SetStr() failed: %v", err1) + } + time.Sleep(1 * time.Millisecond) + timeout := defaultAdapter.GetStrTimeout("k") + t.Logf("get timeout = %v", timeout) + if timeout > 8 { + t.Errorf("GetStrTimeout() failed: %v", timeout) + } + + if err := defaultAdapter.UpdateStrTimeout("k", -1); err != nil { + t.Errorf("UpdateStrTimeout() failed: %v", err) + } + + err2 := defaultAdapter.UpdateStrTimeout("k", 9) + if err2 != nil { + t.Errorf("UpdateStrTimeout() failed: %v", err2) + } + + timeout = defaultAdapter.GetStrTimeout("k") + t.Logf("get timeout = %v", timeout) + + getRes := defaultAdapter.GetStr("k") + if getRes != "v" { + t.Errorf("GetStr() failed: %v", getRes) + } + + err3 := defaultAdapter.UpdateStr("k", "L") + if err3 != nil { + t.Errorf("UpdateStr() failed: %v", err3) + } + + getRes = defaultAdapter.GetStr("k") + if getRes != "L" { + t.Errorf("GetStr() failed: GetStr() = %v want 'L' ", getRes) + } + + err4 := defaultAdapter.DeleteStr("k") + if err4 != nil { + t.Errorf("DeleteStr() failed: %v", err4) + } + err5 := defaultAdapter.UpdateStr("k", "L") + if err5 == nil { + t.Errorf("UpdateStr() failed: update not exist data") + } + + getRes = defaultAdapter.GetStr("k") + if getRes != "" { + t.Errorf("GetStr() failed: %v", getRes) + } +} + +func TestDefaultAdapter_InterfaceOperation(t *testing.T) { + defaultAdapter := NewTestDefaultAdapter() + if err := defaultAdapter.Set("k1", "v1", 0); err == nil { + t.Errorf("Set() failed: set timeout = 0") + } + + if err := defaultAdapter.Set("k1", "v1", -2); err == nil { + t.Errorf("Set() failed: set timeout = -2") + } + + if err := defaultAdapter.Set("k2", "v2", -1); err != nil { + t.Errorf("Set() failed: can't set data") + } + if v := defaultAdapter.Get("k2"); v.(string) != "v2" { + t.Errorf("Get() failed: value is %s, want 'v2' ", v.(string)) + } + if v := defaultAdapter.GetTimeout("k2"); v != -1 { + t.Errorf("GetTimeout() failed: timeout is %v,want -1 ", v) + } + if v := defaultAdapter.GetTimeout("k3"); v != -2 { + t.Errorf("GetTimeout() failed: timeout is %v,want -2 ", v) + } + + if err := defaultAdapter.Set("k1", "v1", 1); err != nil { + t.Errorf("Set() failed: can't set data") + } + time.Sleep(1 * time.Second) + if v := defaultAdapter.Get("k1"); v != nil { + t.Errorf("getExpireAndDelete() faliled: get expired value") + } + + err1 := defaultAdapter.Set("k", "v", 9) + if err1 != nil { + t.Errorf("Set() failed: %v", err1) + } + time.Sleep(1 * time.Millisecond) + + timeout := defaultAdapter.GetTimeout("k") + t.Logf("get timeout = %v", timeout) + if timeout > 8 { + t.Errorf("GetTimeout() failed: %v", timeout) + } + + if err := defaultAdapter.UpdateTimeout("k", -1); err != nil { + t.Errorf("UpdateTimeout() failed: %v", err) + } + err2 := defaultAdapter.UpdateTimeout("k", 9) + if err2 != nil { + t.Errorf("UpdateTimeout() failed: %v", err2) + } + + timeout = defaultAdapter.GetTimeout("k") + t.Logf("get timeout = %v", timeout) + + getRes := defaultAdapter.Get("k") + if getRes.(string) != "v" { + t.Errorf("Get() failed: Get() = %v, want 'v' ", getRes.(string)) + } + + err3 := defaultAdapter.Update("k", "L") + if err3 != nil { + t.Errorf("Update() failed: %v", err3) + } + + getRes = defaultAdapter.Get("k") + if getRes.(string) != "L" { + t.Errorf("Get() failed: Get() = %v want 'L' ", getRes.(string)) + } + + err4 := defaultAdapter.Delete("k") + if err4 != nil { + t.Errorf("Delete() failed: %v", err4) + } + err5 := defaultAdapter.Update("k", "L") + if err5 == nil { + t.Errorf("Update() failed: update not exist data") + } + + getRes = defaultAdapter.Get("k") + if getRes != nil { + t.Errorf("Get() failed: %v s", getRes) + } + +} + +func TestDefaultAdapter_DeleteBatchFilteredValue(t *testing.T) { + adapter := NewTestDefaultAdapter() + if err := adapter.SetStr("k_1", "v", -1); err != nil { + t.Errorf("SetStr() failed: %v", err) + } + if err := adapter.SetStr("k_2", "v", -1); err != nil { + t.Errorf("SetStr() failed: %v", err) + } + if err := adapter.SetStr("k_3", "v", -1); err != nil { + t.Errorf("SetStr() failed: %v", err) + } + err := adapter.DeleteBatchFilteredKey("k_") + if err != nil { + t.Errorf("DeleteBatchFilteredKey() failed: %v", err) + } + str := adapter.GetStr("k_2") + if str != "" { + t.Errorf("DeleteBatchFilteredKey() failed") + } +}