Files
storage/memory/memory.go
Jason McNeil 4e90d76c39 Also copy value in Get() to prevent caller mutations
Returning the stored slice directly allows callers to mutate the stored
data, which defeats the purpose of the defensive copying in Set().

Since this package doesn't have access to gofiber/utils, we manually
copy the slice using make() and copy().

This completes the fix by ensuring stored data cannot be corrupted
either on input (Set) or through mutation of the returned slice (Get).
2025-10-31 07:27:45 -03:00

193 lines
3.9 KiB
Go

package memory
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/storage/memory/v2/internal"
)
// Storage is an in-memory key/value store. Entries are kept in a map guarded
// by mux; a background goroutine (gc) periodically removes expired entries
// until Close is called.
type Storage struct {
	mux        sync.RWMutex     // guards db
	db         map[string]entry // maps keys to their stored value and expiry
	gcInterval time.Duration    // interval between expired-entry sweeps in gc
	done       chan struct{}    // Close sends on it to stop gc
}
// entry is a single stored value together with its expiry time.
type entry struct {
	data []byte
	// expiry is compared against internal.Timestamp and is in whole seconds
	// (Set adds exp.Seconds() to it); 0 means the entry never expires.
	expiry uint32 // max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000
}
// New builds a memory Storage from the optional config (missing values are
// filled in by configDefault), starts the internal timestamp updater, and
// launches the background garbage-collector goroutine. Call Close to stop gc.
func New(config ...Config) *Storage {
	cfg := configDefault(config...)

	s := new(Storage)
	s.db = make(map[string]entry)
	s.gcInterval = cfg.GCInterval
	s.done = make(chan struct{})

	// Expiry checks read internal.Timestamp, so the updater must be running.
	internal.StartTimeStampUpdater()
	go s.gc()

	return s
}
// Get returns a copy of the value stored under key. It returns (nil, nil)
// when key is empty, missing, or expired. A fresh copy is returned so callers
// cannot mutate the stored data through the result.
func (s *Storage) Get(key string) ([]byte, error) {
	if key == "" {
		return nil, nil
	}

	s.mux.RLock()
	e, found := s.db[key]
	s.mux.RUnlock()

	if !found {
		return nil, nil
	}
	if now := atomic.LoadUint32(&internal.Timestamp); e.expiry != 0 && e.expiry <= now {
		return nil, nil
	}

	out := make([]byte, len(e.data))
	copy(out, e.data)
	return out, nil
}
// GetWithContext behaves exactly like Get. The context is accepted only for
// interface compatibility; it is never consulted (no cancellation/deadline).
func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
	val, err := s.Get(key)
	return val, err
}
// Set stores val under key with an optional time-to-live.
//
// Empty keys and empty values are ignored: nothing is stored and nil is
// returned. exp == 0 means the entry never expires. Both key and val are
// copied, so the caller may reuse its buffers after Set returns.
//
// Expiry is tracked as a uint32 count of seconds. A TTL too large to fit is
// clamped to the maximum representable expiry rather than wrapping around
// into the past (which would make the entry expire immediately). A negative
// TTL is treated as already expired instead of relying on the
// implementation-dependent float-to-uint32 conversion.
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
	if len(key) <= 0 || len(val) <= 0 {
		return nil
	}

	// Copy both key and value to avoid unsafe reuse from sync.Pool.
	// When Fiber uses pooled buffers, the underlying memory can be reused.
	keyCopy := string([]byte(key))
	valCopy := make([]byte, len(val))
	copy(valCopy, val)

	var expire uint32
	if exp != 0 {
		const maxExpiry = ^uint32(0)
		now := atomic.LoadUint32(&internal.Timestamp)
		switch secs := exp.Seconds(); {
		case secs < 0:
			// Get treats expiry <= now as expired.
			expire = now
		case secs >= float64(maxExpiry-now):
			// now+secs would overflow uint32 and wrap into the past.
			expire = maxExpiry
		default:
			expire = now + uint32(secs)
		}
	}

	e := entry{valCopy, expire}
	s.mux.Lock()
	s.db[keyCopy] = e
	s.mux.Unlock()
	return nil
}
// SetWithContext behaves exactly like Set. The context is accepted only for
// interface compatibility; it is never consulted (no cancellation/deadline).
func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
	err := s.Set(key, val, exp)
	return err
}
// Delete removes key from the store. An empty key is a no-op, and deleting a
// key that is not present is not an error; nil is always returned.
func (s *Storage) Delete(key string) error {
	if key != "" {
		s.mux.Lock()
		delete(s.db, key)
		s.mux.Unlock()
	}
	return nil
}
// DeleteWithContext behaves exactly like Delete. The context is accepted only
// for interface compatibility; it is never consulted.
func (s *Storage) DeleteWithContext(ctx context.Context, key string) error {
	err := s.Delete(key)
	return err
}
// Reset drops every key by swapping in a brand-new empty map. The map is
// allocated before the lock is taken so the critical section stays short.
// A map previously returned by Conn is not cleared; it is simply abandoned.
func (s *Storage) Reset() error {
	fresh := make(map[string]entry)

	s.mux.Lock()
	defer s.mux.Unlock()
	s.db = fresh

	return nil
}
// ResetWithContext behaves exactly like Reset. The context is accepted only
// for interface compatibility; it is never consulted.
func (s *Storage) ResetWithContext(ctx context.Context) error {
	err := s.Reset()
	return err
}
// Close stops the background gc goroutine by sending on s.done. The send
// blocks until gc receives it. Close always returns nil.
//
// NOTE(review): done is created unbuffered in New and gc returns after the
// first receive, so calling Close a second time blocks forever.
func (s *Storage) Close() error {
	stop := struct{}{}
	s.done <- stop
	return nil
}
// gc runs until Close signals s.done. Every gcInterval it deletes the entries
// whose expiry has been reached, using the same rule as Get and Keys: an
// entry is expired when expiry != 0 && expiry <= now.
//
// Work is split into two phases. The first collects candidates under the read
// lock so concurrent readers are not blocked during the scan. The second
// deletes them under the write lock, after re-checking each one.
func (s *Storage) gc() {
	ticker := time.NewTicker(s.gcInterval)
	defer ticker.Stop()

	// Reused across ticks to avoid reallocating the candidate list each sweep.
	var expired []string
	for {
		select {
		case <-s.done:
			return
		case <-ticker.C:
			ts := atomic.LoadUint32(&internal.Timestamp)
			expired = expired[:0]

			// Use <= (not <) so entries that Get/Keys already report as
			// expired are collected on this tick rather than the next one.
			s.mux.RLock()
			for id, v := range s.db {
				if v.expiry != 0 && v.expiry <= ts {
					expired = append(expired, id)
				}
			}
			s.mux.RUnlock()

			if len(expired) == 0 {
				// Nothing to delete; skip taking the write lock.
				continue
			}

			// Re-check under the write lock: a key may have been overwritten
			// with a new expiry, or deleted, between the two lock phases.
			s.mux.Lock()
			for _, id := range expired {
				if v := s.db[id]; v.expiry != 0 && v.expiry <= ts {
					delete(s.db, id)
				}
			}
			s.mux.Unlock()
		}
	}
}
// Conn returns the map backing this storage.
//
// NOTE(review): this is the live internal map, not a copy. The read lock is
// released before the caller uses the result, so accessing the map while other
// Storage methods run concurrently is a data race. Reset also replaces the
// map, after which a previously returned value is stale.
func (s *Storage) Conn() map[string]entry {
	s.mux.RLock()
	db := s.db
	s.mux.RUnlock()
	return db
}
// Keys returns a copy of every key whose entry has not expired, in map
// iteration (unspecified) order. It returns (nil, nil) when there are no live
// keys.
func (s *Storage) Keys() ([][]byte, error) {
	s.mux.RLock()
	defer s.mux.RUnlock()

	if len(s.db) == 0 {
		return nil, nil
	}

	now := atomic.LoadUint32(&internal.Timestamp)
	live := make([][]byte, 0, len(s.db))
	for k, e := range s.db {
		if expired := e.expiry != 0 && e.expiry <= now; !expired {
			live = append(live, []byte(k))
		}
	}

	if len(live) == 0 {
		return nil, nil
	}
	return live, nil
}