mirror of
https://github.com/gofiber/storage.git
synced 2025-10-05 08:37:10 +08:00
@@ -1,22 +1,96 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/utils"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Storage interface that is implemented by storage providers
|
||||
type Storage struct {
|
||||
db *sql.DB
|
||||
gcInterval time.Duration
|
||||
|
||||
sqlSelect string
|
||||
sqlInsert string
|
||||
sqlDelete string
|
||||
sqlClear string
|
||||
sqlGC string
|
||||
}
|
||||
|
||||
var (
|
||||
dropQuery = `DROP TABLE IF EXISTS %s;`
|
||||
initQuery = []string{
|
||||
`CREATE TABLE IF NOT EXISTS %s (
|
||||
key VARCHAR(64) PRIMARY KEY NOT NULL DEFAULT '',
|
||||
data TEXT NOT NULL,
|
||||
exp BIGINT NOT NULL DEFAULT '0'
|
||||
);`,
|
||||
`CREATE INDEX IF NOT EXISTS exp ON %s (exp);`,
|
||||
}
|
||||
)
|
||||
|
||||
// New creates a new storage
|
||||
func New(config ...Config) *Storage {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Create data source name
|
||||
dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?connect_timeout=%d&sslmode=disable",
|
||||
url.QueryEscape(cfg.Username),
|
||||
cfg.Password,
|
||||
url.QueryEscape(cfg.Host),
|
||||
cfg.Port,
|
||||
url.QueryEscape(cfg.Database),
|
||||
int64(cfg.Timeout.Seconds()))
|
||||
|
||||
// Create db
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Set database options
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||
|
||||
// Ping database
|
||||
if err := db.Ping(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Drop table if set to true
|
||||
if cfg.DropTable {
|
||||
if _, err = db.Exec(fmt.Sprintf(dropQuery, cfg.TableName)); err != nil {
|
||||
_ = db.Close()
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Init database queries
|
||||
for _, query := range initQuery {
|
||||
if _, err := db.Exec(fmt.Sprintf(query, cfg.TableName)); err != nil {
|
||||
_ = db.Close()
|
||||
fmt.Println(fmt.Sprintf(query, cfg.TableName))
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create storage
|
||||
store := &Storage{
|
||||
db: db,
|
||||
gcInterval: cfg.GCInterval,
|
||||
sqlSelect: fmt.Sprintf(`SELECT data, exp FROM %s WHERE key=$1;`, cfg.TableName),
|
||||
sqlInsert: fmt.Sprintf("INSERT INTO %s (key, data, exp) VALUES ($1, $2, $3)", cfg.TableName),
|
||||
sqlDelete: fmt.Sprintf("DELETE FROM %s WHERE key=$1", cfg.TableName),
|
||||
sqlClear: fmt.Sprintf("DELETE FROM %s;", cfg.TableName),
|
||||
sqlGC: fmt.Sprintf("DELETE FROM %s WHERE exp <= $1", cfg.TableName),
|
||||
}
|
||||
|
||||
// Start garbage collector
|
||||
@@ -25,31 +99,57 @@ func New(config ...Config) *Storage {
|
||||
return store
|
||||
}
|
||||
|
||||
var noRows = errors.New("sql: no rows in result set")
|
||||
|
||||
// Get value by key
|
||||
func (s *Storage) Get(key string) ([]byte, error) {
|
||||
return nil, nil
|
||||
row := s.db.QueryRow(s.sqlSelect, key)
|
||||
|
||||
// Add db response to data
|
||||
var (
|
||||
data = []byte{}
|
||||
exp int64 = 0
|
||||
)
|
||||
if err := row.Scan(&data, &exp); err != nil {
|
||||
if err != noRows {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If the expiration time has already passed, then return nil
|
||||
if time.Now().After(time.Unix(exp, 0)) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Set key with value
|
||||
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
|
||||
return nil
|
||||
_, err := s.db.Exec(s.sqlInsert, key, utils.UnsafeString(val), time.Now().Add(exp).Unix())
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete key by key
|
||||
// Delete entry by key
|
||||
func (s *Storage) Delete(key string) error {
|
||||
return nil
|
||||
_, err := s.db.Exec(s.sqlDelete, key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear all keys
|
||||
// Clear all entries, including unexpired
|
||||
func (s *Storage) Clear() error {
|
||||
return nil
|
||||
_, err := s.db.Exec(s.sqlClear)
|
||||
return err
|
||||
}
|
||||
|
||||
// Garbage collector to delete expired keys
|
||||
// GC deletes all expired entries
|
||||
func (s *Storage) gc() {
|
||||
tick := time.NewTicker(s.gcInterval)
|
||||
for {
|
||||
<-tick.C
|
||||
// clean entries
|
||||
if _, err := s.db.Exec(s.sqlGC); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user