Removed Postgre Close test since the pgx implementation of Close() doesn't return an error,

Removed Storage.Close()'s return method since the pgx implementation of Close() doesn't return an error,
Removed Config.maxIdleConns since pgx implementation doesn't support setting a maximum number of idle connections
This commit is contained in:
Technerder
2021-07-02 18:04:33 -04:00
parent f0233feaea
commit 85e612ec79
4 changed files with 35 additions and 48 deletions

View File

@@ -1,6 +1,6 @@
# Postgres # Postgres
A Postgres storage driver using [lib/pq](https://github.com/lib/pq). A Postgres storage driver using [jackc/pgx](https://github.com/jackc/pgx).
### Table of Contents ### Table of Contents
- [Signatures](#signatures) - [Signatures](#signatures)
@@ -16,7 +16,7 @@ func (s *Storage) Get(key string) ([]byte, error)
func (s *Storage) Set(key string, val []byte, exp time.Duration) error func (s *Storage) Set(key string, val []byte, exp time.Duration) error
func (s *Storage) Delete(key string) error func (s *Storage) Delete(key string) error
func (s *Storage) Reset() error func (s *Storage) Reset() error
func (s *Storage) Close() error func (s *Storage) Close()
``` ```
### Installation ### Installation
Postgres is tested on the 2 last [Go versions](https://golang.org/dl/) with support for modules. So make sure to initialize one first if you didn't do that yet: Postgres is tested on the 2 last [Go versions](https://golang.org/dl/) with support for modules. So make sure to initialize one first if you didn't do that yet:

View File

@@ -59,17 +59,6 @@ type Config struct {
// n < 0 means wait indefinitely. // n < 0 means wait indefinitely.
timeout time.Duration timeout time.Duration
// The maximum number of connections in the idle connection pool.
//
// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
//
// If n <= 0, no idle connections are retained.
//
// The default max idle connections is currently 2. This may change in
// a future release.
maxIdleConns int
// The maximum number of open connections to the database. // The maximum number of open connections to the database.
// //
// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than // If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
@@ -78,7 +67,7 @@ type Config struct {
// //
// If n <= 0, then there is no limit on the number of open connections. // If n <= 0, then there is no limit on the number of open connections.
// The default is 0 (unlimited). // The default is 0 (unlimited).
maxOpenConns int maxOpenConns int32
// The maximum amount of time a connection may be reused. // The maximum amount of time a connection may be reused.
// //
@@ -98,7 +87,6 @@ var ConfigDefault = Config{
Reset: false, Reset: false,
GCInterval: 10 * time.Second, GCInterval: 10 * time.Second,
maxOpenConns: 100, maxOpenConns: 100,
maxIdleConns: 100,
connMaxLifetime: 1 * time.Second, connMaxLifetime: 1 * time.Second,
} }

View File

@@ -1,19 +1,19 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"net/url" "net/url"
"strings" "strings"
"time" "time"
_ "github.com/lib/pq" "github.com/jackc/pgx/v4/pgxpool"
) )
// Storage interface that is implemented by storage providers // Storage interface that is implemented by storage providers
type Storage struct { type Storage struct {
db *sql.DB db *pgxpool.Pool
gcInterval time.Duration gcInterval time.Duration
done chan struct{} done chan struct{}
@@ -46,7 +46,7 @@ func New(config ...Config) *Storage {
cfg := configDefault(config...) cfg := configDefault(config...)
// Create data source name // Create data source name
var dsn string = "postgresql://" var dsn = "postgresql://"
if cfg.Username != "" { if cfg.Username != "" {
dsn += url.QueryEscape(cfg.Username) dsn += url.QueryEscape(cfg.Username)
} }
@@ -63,35 +63,34 @@ func New(config ...Config) *Storage {
cfg.SslMode, cfg.SslMode,
) )
// Create db cnf, _ := pgxpool.ParseConfig(dsn)
db, err := sql.Open("postgres", dsn) cnf.MaxConns = cfg.maxOpenConns
cnf.MaxConnLifetime = cfg.connMaxLifetime
cnf.MaxConnIdleTime = cfg.connMaxLifetime
db, err := pgxpool.ConnectConfig(context.Background(), cnf)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer db.Close()
// Set database options
db.SetMaxOpenConns(cfg.maxOpenConns)
db.SetMaxIdleConns(cfg.maxIdleConns)
db.SetConnMaxLifetime(cfg.connMaxLifetime)
// Ping database // Ping database
if err := db.Ping(); err != nil { if err := db.Ping(context.Background()); err != nil {
panic(err) panic(err)
} }
// Drop table if set to true // Drop table if set to true
if cfg.Reset { if cfg.Reset {
if _, err = db.Exec(fmt.Sprintf(dropQuery, cfg.Table)); err != nil { if _, err = db.Exec(context.Background(), fmt.Sprintf(dropQuery, cfg.Table)); err != nil {
_ = db.Close() db.Close()
panic(err) panic(err)
} }
} }
// Init database queries // Init database queries
for _, query := range initQuery { for _, query := range initQuery {
if _, err := db.Exec(fmt.Sprintf(query, cfg.Table)); err != nil { if _, err := db.Exec(context.Background(), fmt.Sprintf(query, cfg.Table)); err != nil {
_ = db.Close() db.Close()
panic(err) panic(err)
} }
} }
@@ -116,17 +115,15 @@ func New(config ...Config) *Storage {
return store return store
} }
var noRows = errors.New("sql: no rows in result set")
// Get value by key // Get value by key
func (s *Storage) Get(key string) ([]byte, error) { func (s *Storage) Get(key string) ([]byte, error) {
if len(key) <= 0 { if len(key) <= 0 {
return nil, nil return nil, nil
} }
row := s.db.QueryRow(s.sqlSelect, key) row := s.db.QueryRow(context.Background(), s.sqlSelect, key)
// Add db response to data // Add db response to data
var ( var (
data = []byte{} data []byte
exp int64 = 0 exp int64 = 0
) )
if err := row.Scan(&data, &exp); err != nil { if err := row.Scan(&data, &exp); err != nil {
@@ -154,7 +151,7 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
if exp != 0 { if exp != 0 {
expSeconds = time.Now().Add(exp).Unix() expSeconds = time.Now().Add(exp).Unix()
} }
_, err := s.db.Exec(s.sqlInsert, key, val, expSeconds, val, expSeconds) _, err := s.db.Exec(context.Background(), s.sqlInsert, key, val, expSeconds, val, expSeconds)
return err return err
} }
@@ -164,20 +161,21 @@ func (s *Storage) Delete(key string) error {
if len(key) <= 0 { if len(key) <= 0 {
return nil return nil
} }
_, err := s.db.Exec(s.sqlDelete, key) _, err := s.db.Exec(context.Background(), s.sqlDelete, key)
return err return err
} }
// Reset all entries, including unexpired // Reset all entries, including unexpired
func (s *Storage) Reset() error { func (s *Storage) Reset() error {
_, err := s.db.Exec(s.sqlReset) _, err := s.db.Exec(context.Background(), s.sqlReset)
return err return err
} }
// Close the database // Close the database
func (s *Storage) Close() error { func (s *Storage) Close() {
s.done <- struct{}{} s.done <- struct{}{}
return s.db.Close() s.db.Stat()
s.db.Close()
} }
// gcTicker starts the gc ticker // gcTicker starts the gc ticker
@@ -196,13 +194,13 @@ func (s *Storage) gcTicker() {
// gc deletes all expired entries // gc deletes all expired entries
func (s *Storage) gc(t time.Time) { func (s *Storage) gc(t time.Time) {
_, _ = s.db.Exec(s.sqlGC, t.Unix()) _, _ = s.db.Exec(context.Background(), s.sqlGC, t.Unix())
} }
func (s *Storage) checkSchema(tableName string) { func (s *Storage) checkSchema(tableName string) {
var data []byte var data []byte
row := s.db.QueryRow(fmt.Sprintf(checkSchemaQuery, tableName)) row := s.db.QueryRow(context.Background(), fmt.Sprintf(checkSchemaQuery, tableName))
if err := row.Scan(&data); err != nil { if err := row.Scan(&data); err != nil {
panic(err) panic(err)
} }
@@ -211,3 +209,7 @@ func (s *Storage) checkSchema(tableName string) {
fmt.Printf(checkSchemaMsg, string(data)) fmt.Printf(checkSchemaMsg, string(data))
} }
} }
func (s *Storage) DB() *pgxpool.Pool {
return s.db
}

View File

@@ -1,6 +1,7 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"os" "os"
"testing" "testing"
@@ -133,7 +134,7 @@ func Test_Postgres_GC(t *testing.T) {
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
testStore.gc(time.Now()) testStore.gc(time.Now())
row := testStore.db.QueryRow(testStore.sqlSelect, "john") row := testStore.db.QueryRow(context.Background(), testStore.sqlSelect, "john")
err = row.Scan(nil, nil) err = row.Scan(nil, nil)
utils.AssertEqual(t, sql.ErrNoRows, err) utils.AssertEqual(t, sql.ErrNoRows, err)
@@ -173,7 +174,3 @@ func Test_SslRequiredMode(t *testing.T) {
SslMode: "require", SslMode: "require",
}) })
} }
func Test_Postgres_Close(t *testing.T) {
utils.AssertEqual(t, nil, testStore.Close())
}