mirror of
https://github.com/gofiber/storage.git
synced 2025-10-05 16:48:25 +08:00
Removed Postgre Close test since the pgx implementation of Close() doesn't return an error,
Removed Storage.Close()'s return method since the pgx implementation of Close() doesn't return an error, Removed Config.maxIdleConns since pgx implementation doesn't support setting a maximum number of idle connections
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
// Storage interface that is implemented by storage providers
|
||||
type Storage struct {
|
||||
db *sql.DB
|
||||
db *pgxpool.Pool
|
||||
gcInterval time.Duration
|
||||
done chan struct{}
|
||||
|
||||
@@ -46,7 +46,7 @@ func New(config ...Config) *Storage {
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Create data source name
|
||||
var dsn string = "postgresql://"
|
||||
var dsn = "postgresql://"
|
||||
if cfg.Username != "" {
|
||||
dsn += url.QueryEscape(cfg.Username)
|
||||
}
|
||||
@@ -63,35 +63,34 @@ func New(config ...Config) *Storage {
|
||||
cfg.SslMode,
|
||||
)
|
||||
|
||||
// Create db
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
cnf, _ := pgxpool.ParseConfig(dsn)
|
||||
cnf.MaxConns = cfg.maxOpenConns
|
||||
cnf.MaxConnLifetime = cfg.connMaxLifetime
|
||||
cnf.MaxConnIdleTime = cfg.connMaxLifetime
|
||||
|
||||
db, err := pgxpool.ConnectConfig(context.Background(), cnf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Set database options
|
||||
db.SetMaxOpenConns(cfg.maxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.maxIdleConns)
|
||||
db.SetConnMaxLifetime(cfg.connMaxLifetime)
|
||||
defer db.Close()
|
||||
|
||||
// Ping database
|
||||
if err := db.Ping(); err != nil {
|
||||
if err := db.Ping(context.Background()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Drop table if set to true
|
||||
if cfg.Reset {
|
||||
if _, err = db.Exec(fmt.Sprintf(dropQuery, cfg.Table)); err != nil {
|
||||
_ = db.Close()
|
||||
if _, err = db.Exec(context.Background(), fmt.Sprintf(dropQuery, cfg.Table)); err != nil {
|
||||
db.Close()
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Init database queries
|
||||
for _, query := range initQuery {
|
||||
if _, err := db.Exec(fmt.Sprintf(query, cfg.Table)); err != nil {
|
||||
_ = db.Close()
|
||||
|
||||
if _, err := db.Exec(context.Background(), fmt.Sprintf(query, cfg.Table)); err != nil {
|
||||
db.Close()
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -116,17 +115,15 @@ func New(config ...Config) *Storage {
|
||||
return store
|
||||
}
|
||||
|
||||
var noRows = errors.New("sql: no rows in result set")
|
||||
|
||||
// Get value by key
|
||||
func (s *Storage) Get(key string) ([]byte, error) {
|
||||
if len(key) <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
row := s.db.QueryRow(s.sqlSelect, key)
|
||||
row := s.db.QueryRow(context.Background(), s.sqlSelect, key)
|
||||
// Add db response to data
|
||||
var (
|
||||
data = []byte{}
|
||||
data []byte
|
||||
exp int64 = 0
|
||||
)
|
||||
if err := row.Scan(&data, &exp); err != nil {
|
||||
@@ -154,7 +151,7 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
|
||||
if exp != 0 {
|
||||
expSeconds = time.Now().Add(exp).Unix()
|
||||
}
|
||||
_, err := s.db.Exec(s.sqlInsert, key, val, expSeconds, val, expSeconds)
|
||||
_, err := s.db.Exec(context.Background(), s.sqlInsert, key, val, expSeconds, val, expSeconds)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -164,20 +161,21 @@ func (s *Storage) Delete(key string) error {
|
||||
if len(key) <= 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := s.db.Exec(s.sqlDelete, key)
|
||||
_, err := s.db.Exec(context.Background(), s.sqlDelete, key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Reset all entries, including unexpired
|
||||
func (s *Storage) Reset() error {
|
||||
_, err := s.db.Exec(s.sqlReset)
|
||||
_, err := s.db.Exec(context.Background(), s.sqlReset)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close the database
|
||||
func (s *Storage) Close() error {
|
||||
func (s *Storage) Close() {
|
||||
s.done <- struct{}{}
|
||||
return s.db.Close()
|
||||
s.db.Stat()
|
||||
s.db.Close()
|
||||
}
|
||||
|
||||
// gcTicker starts the gc ticker
|
||||
@@ -196,13 +194,13 @@ func (s *Storage) gcTicker() {
|
||||
|
||||
// gc deletes all expired entries
|
||||
func (s *Storage) gc(t time.Time) {
|
||||
_, _ = s.db.Exec(s.sqlGC, t.Unix())
|
||||
_, _ = s.db.Exec(context.Background(), s.sqlGC, t.Unix())
|
||||
}
|
||||
|
||||
func (s *Storage) checkSchema(tableName string) {
|
||||
var data []byte
|
||||
|
||||
row := s.db.QueryRow(fmt.Sprintf(checkSchemaQuery, tableName))
|
||||
row := s.db.QueryRow(context.Background(), fmt.Sprintf(checkSchemaQuery, tableName))
|
||||
if err := row.Scan(&data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -211,3 +209,7 @@ func (s *Storage) checkSchema(tableName string) {
|
||||
fmt.Printf(checkSchemaMsg, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Storage) DB() *pgxpool.Pool {
|
||||
return s.db
|
||||
}
|
||||
|
Reference in New Issue
Block a user