mirror of
https://github.com/gofiber/storage.git
synced 2025-12-24 13:29:30 +08:00
Returning the stored slice directly allows callers to mutate the stored data, which defeats the purpose of the defensive copying in Set(). Since this package doesn't have access to gofiber/utils, we manually copy the slice using make() and copy(). This completes the fix: stored data can no longer be corrupted through the slice passed in on input (Set) or mutated through the slice handed back on output (Get).
193 lines
3.9 KiB
Go
193 lines
3.9 KiB
Go
package memory
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/gofiber/storage/memory/v2/internal"
|
|
)
|
|
|
|
// Storage interface that is implemented by storage providers
|
|
type Storage struct {
|
|
mux sync.RWMutex
|
|
db map[string]entry
|
|
gcInterval time.Duration
|
|
done chan struct{}
|
|
}
|
|
|
|
// entry is a single stored value together with its expiry.
//
// The field order matters: entries are built elsewhere in this file with a
// positional composite literal (data first, expiry second).
type entry struct {
	// data is the stored value.
	data []byte
	// expiry is compared against internal.Timestamp; 0 means the entry
	// never expires.
	// max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000
	expiry uint32
}
|
|
|
|
// New creates a new memory storage
|
|
func New(config ...Config) *Storage {
|
|
// Set default config
|
|
cfg := configDefault(config...)
|
|
|
|
// Create storage
|
|
store := &Storage{
|
|
db: make(map[string]entry),
|
|
gcInterval: cfg.GCInterval,
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
// Start garbage collector
|
|
internal.StartTimeStampUpdater()
|
|
go store.gc()
|
|
|
|
return store
|
|
}
|
|
|
|
// Get value by key
|
|
func (s *Storage) Get(key string) ([]byte, error) {
|
|
if len(key) <= 0 {
|
|
return nil, nil
|
|
}
|
|
s.mux.RLock()
|
|
v, ok := s.db[key]
|
|
s.mux.RUnlock()
|
|
if !ok || (v.expiry != 0 && v.expiry <= atomic.LoadUint32(&internal.Timestamp)) {
|
|
return nil, nil
|
|
}
|
|
|
|
// Return a copy to prevent callers from mutating stored data
|
|
valCopy := make([]byte, len(v.data))
|
|
copy(valCopy, v.data)
|
|
return valCopy, nil
|
|
}
|
|
|
|
// GetWithContext gets value by key (dummy context support)
|
|
func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
|
|
return s.Get(key)
|
|
}
|
|
|
|
// Set key with value
|
|
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
|
|
if len(key) <= 0 || len(val) <= 0 {
|
|
return nil
|
|
}
|
|
|
|
var expire uint32
|
|
// Copy both key and value to avoid unsafe reuse from sync.Pool
|
|
// When Fiber uses pooled buffers, the underlying memory can be reused
|
|
keyCopy := string([]byte(key))
|
|
valCopy := make([]byte, len(val))
|
|
copy(valCopy, val)
|
|
|
|
if exp != 0 {
|
|
expire = uint32(exp.Seconds()) + atomic.LoadUint32(&internal.Timestamp)
|
|
}
|
|
|
|
e := entry{valCopy, expire}
|
|
s.mux.Lock()
|
|
s.db[keyCopy] = e
|
|
s.mux.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// SetWithContext sets value by key (dummy context support)
|
|
func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
|
|
return s.Set(key, val, exp)
|
|
}
|
|
|
|
// Delete key by key
|
|
func (s *Storage) Delete(key string) error {
|
|
if len(key) <= 0 {
|
|
return nil
|
|
}
|
|
s.mux.Lock()
|
|
delete(s.db, key)
|
|
s.mux.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// DeleteWithContext deletes key (dummy context support)
|
|
func (s *Storage) DeleteWithContext(ctx context.Context, key string) error {
|
|
return s.Delete(key)
|
|
}
|
|
|
|
// Reset all keys
|
|
func (s *Storage) Reset() error {
|
|
ndb := make(map[string]entry)
|
|
s.mux.Lock()
|
|
s.db = ndb
|
|
s.mux.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// ResetWithContext resets all keys (dummy context support)
|
|
func (s *Storage) ResetWithContext(ctx context.Context) error {
|
|
return s.Reset()
|
|
}
|
|
|
|
// Close the memory storage
|
|
func (s *Storage) Close() error {
|
|
s.done <- struct{}{}
|
|
return nil
|
|
}
|
|
|
|
func (s *Storage) gc() {
|
|
ticker := time.NewTicker(s.gcInterval)
|
|
defer ticker.Stop()
|
|
var expired []string
|
|
|
|
for {
|
|
select {
|
|
case <-s.done:
|
|
return
|
|
case <-ticker.C:
|
|
ts := atomic.LoadUint32(&internal.Timestamp)
|
|
expired = expired[:0]
|
|
s.mux.RLock()
|
|
for id, v := range s.db {
|
|
if v.expiry != 0 && v.expiry < ts {
|
|
expired = append(expired, id)
|
|
}
|
|
}
|
|
s.mux.RUnlock()
|
|
s.mux.Lock()
|
|
for i := range expired {
|
|
v := s.db[expired[i]]
|
|
if v.expiry != 0 && v.expiry <= ts {
|
|
delete(s.db, expired[i])
|
|
}
|
|
}
|
|
s.mux.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Conn returns database client
|
|
func (s *Storage) Conn() map[string]entry {
|
|
s.mux.RLock()
|
|
defer s.mux.RUnlock()
|
|
return s.db
|
|
}
|
|
|
|
// Keys returns all the keys
|
|
func (s *Storage) Keys() ([][]byte, error) {
|
|
s.mux.RLock()
|
|
defer s.mux.RUnlock()
|
|
|
|
if len(s.db) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
ts := atomic.LoadUint32(&internal.Timestamp)
|
|
keys := make([][]byte, 0, len(s.db))
|
|
for key, v := range s.db {
|
|
if v.expiry == 0 || v.expiry > ts {
|
|
keys = append(keys, []byte(key))
|
|
}
|
|
}
|
|
|
|
if len(keys) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return keys, nil
|
|
}
|