diff --git a/postgres/README.md b/postgres/README.md index feada1db..ee2be004 100644 --- a/postgres/README.md +++ b/postgres/README.md @@ -1,6 +1,6 @@ # Postgres -A Postgres storage driver using [lib/pq](https://github.com/lib/pq). +A Postgres storage driver using [jackc/pgx](https://github.com/jackc/pgx). ### Table of Contents - [Signatures](#signatures) @@ -16,7 +16,7 @@ func (s *Storage) Get(key string) ([]byte, error) func (s *Storage) Set(key string, val []byte, exp time.Duration) error func (s *Storage) Delete(key string) error func (s *Storage) Reset() error -func (s *Storage) Close() error +func (s *Storage) Close() ``` ### Installation Postgres is tested on the 2 last [Go versions](https://golang.org/dl/) with support for modules. So make sure to initialize one first if you didn't do that yet: diff --git a/postgres/config.go b/postgres/config.go index 02969125..f1695ac2 100644 --- a/postgres/config.go +++ b/postgres/config.go @@ -59,17 +59,6 @@ type Config struct { // n < 0 means wait indefinitely. timeout time.Duration - // The maximum number of connections in the idle connection pool. - // - // If MaxOpenConns is greater than 0 but less than the new MaxIdleConns, - // then the new MaxIdleConns will be reduced to match the MaxOpenConns limit. - // - // If n <= 0, no idle connections are retained. - // - // The default max idle connections is currently 2. This may change in - // a future release. - maxIdleConns int - // The maximum number of open connections to the database. // // If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than @@ -78,7 +67,7 @@ type Config struct { // // If n <= 0, then there is no limit on the number of open connections. // The default is 0 (unlimited). - maxOpenConns int + maxOpenConns int32 // The maximum amount of time a connection may be reused. // @@ -98,7 +87,6 @@ var ConfigDefault = Config{ Reset: false, GCInterval: 10 * time.Second, maxOpenConns: 100, - maxIdleConns: 100, connMaxLifetime: 1 * time.Second, } diff --git a/postgres/postgres.go b/postgres/postgres.go index ac0dc6b9..054d0d92 100644 --- a/postgres/postgres.go +++ b/postgres/postgres.go @@ -1,19 +1,19 @@ package postgres import ( + "context" "database/sql" - "errors" "fmt" "net/url" "strings" "time" - _ "github.com/lib/pq" + "github.com/jackc/pgx/v4/pgxpool" ) // Storage interface that is implemented by storage providers type Storage struct { - db *sql.DB + db *pgxpool.Pool gcInterval time.Duration done chan struct{} @@ -46,7 +46,7 @@ func New(config ...Config) *Storage { cfg := configDefault(config...) // Create data source name - var dsn string = "postgresql://" + var dsn = "postgresql://" if cfg.Username != "" { dsn += url.QueryEscape(cfg.Username) } @@ -63,35 +63,34 @@ func New(config ...Config) *Storage { cfg.SslMode, ) - // Create db - db, err := sql.Open("postgres", dsn) + cnf, _ := pgxpool.ParseConfig(dsn) + cnf.MaxConns = cfg.maxOpenConns + cnf.MaxConnLifetime = cfg.connMaxLifetime + cnf.MaxConnIdleTime = cfg.connMaxLifetime + + db, err := pgxpool.ConnectConfig(context.Background(), cnf) if err != nil { panic(err) } - - // Set database options - db.SetMaxOpenConns(cfg.maxOpenConns) - db.SetMaxIdleConns(cfg.maxIdleConns) - db.SetConnMaxLifetime(cfg.connMaxLifetime) + defer db.Close() // Ping database - if err := db.Ping(); err != nil { + if err := db.Ping(context.Background()); err != nil { panic(err) } // Drop table if set to true if cfg.Reset { - if _, err = db.Exec(fmt.Sprintf(dropQuery, cfg.Table)); err != nil { - _ = db.Close() + if _, err = db.Exec(context.Background(), fmt.Sprintf(dropQuery, cfg.Table)); err != nil { + db.Close() panic(err) } } // Init database queries for _, query := range initQuery { - if _, err := db.Exec(fmt.Sprintf(query, cfg.Table)); err != nil { - _ = db.Close() - + if _, err := db.Exec(context.Background(), fmt.Sprintf(query, cfg.Table)); err != nil { + db.Close() panic(err) } } @@ -116,17 +115,15 @@ func New(config ...Config) *Storage { return store } -var noRows = errors.New("sql: no rows in result set") - // Get value by key func (s *Storage) Get(key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - row := s.db.QueryRow(s.sqlSelect, key) + row := s.db.QueryRow(context.Background(), s.sqlSelect, key) // Add db response to data var ( - data = []byte{} + data []byte exp int64 = 0 ) if err := row.Scan(&data, &exp); err != nil { @@ -154,7 +151,7 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if exp != 0 { expSeconds = time.Now().Add(exp).Unix() } - _, err := s.db.Exec(s.sqlInsert, key, val, expSeconds, val, expSeconds) + _, err := s.db.Exec(context.Background(), s.sqlInsert, key, val, expSeconds, val, expSeconds) return err } @@ -164,20 +161,21 @@ func (s *Storage) Delete(key string) error { if len(key) <= 0 { return nil } - _, err := s.db.Exec(s.sqlDelete, key) + _, err := s.db.Exec(context.Background(), s.sqlDelete, key) return err } // Reset all entries, including unexpired func (s *Storage) Reset() error { - _, err := s.db.Exec(s.sqlReset) + _, err := s.db.Exec(context.Background(), s.sqlReset) return err } // Close the database -func (s *Storage) Close() error { +func (s *Storage) Close() { s.done <- struct{}{} - return s.db.Close() + s.db.Stat() + s.db.Close() } // gcTicker starts the gc ticker @@ -196,13 +194,13 @@ func (s *Storage) gcTicker() { // gc deletes all expired entries func (s *Storage) gc(t time.Time) { - _, _ = s.db.Exec(s.sqlGC, t.Unix()) + _, _ = s.db.Exec(context.Background(), s.sqlGC, t.Unix()) } func (s *Storage) checkSchema(tableName string) { var data []byte - row := s.db.QueryRow(fmt.Sprintf(checkSchemaQuery, tableName)) + row := s.db.QueryRow(context.Background(), fmt.Sprintf(checkSchemaQuery, tableName)) if err := row.Scan(&data); err != nil { panic(err) } @@ -211,3 +209,7 @@ func (s *Storage) checkSchema(tableName string) { fmt.Printf(checkSchemaMsg, string(data)) } } + +func (s *Storage) DB() *pgxpool.Pool { + return s.db +} diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go index 062b235f..830cca76 100644 --- a/postgres/postgres_test.go +++ b/postgres/postgres_test.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "database/sql" "os" "testing" @@ -133,7 +134,7 @@ func Test_Postgres_GC(t *testing.T) { utils.AssertEqual(t, nil, err) testStore.gc(time.Now()) - row := testStore.db.QueryRow(testStore.sqlSelect, "john") + row := testStore.db.QueryRow(context.Background(), testStore.sqlSelect, "john") err = row.Scan(nil, nil) utils.AssertEqual(t, sql.ErrNoRows, err) @@ -173,7 +174,3 @@ func Test_SslRequiredMode(t *testing.T) { SslMode: "require", }) } - -func Test_Postgres_Close(t *testing.T) { - utils.AssertEqual(t, nil, testStore.Close()) -}