diff --git a/postgres/README.md b/postgres/README.md new file mode 100644 index 00000000..feada1db --- /dev/null +++ b/postgres/README.md @@ -0,0 +1,116 @@ +# Postgres + +A Postgres storage driver using [lib/pq](https://github.com/lib/pq). + +### Table of Contents +- [Signatures](#signatures) +- [Installation](#installation) +- [Examples](#examples) +- [Config](#config) +- [Default Config](#default-config) + +### Signatures +```go +func New(config ...Config) Storage +func (s *Storage) Get(key string) ([]byte, error) +func (s *Storage) Set(key string, val []byte, exp time.Duration) error +func (s *Storage) Delete(key string) error +func (s *Storage) Reset() error +func (s *Storage) Close() error +``` +### Installation +Postgres is tested on the 2 last [Go versions](https://golang.org/dl/) with support for modules. So make sure to initialize one first if you didn't do that yet: +```bash +go mod init github.com// +``` +And then install the postgres implementation: +```bash +go get github.com/gofiber/storage/postgres +``` + +### Examples +Import the storage package. +```go +import "github.com/gofiber/storage/postgres" +``` + +You can use the following possibilities to create a storage: +```go +// Initialize default config +store := postgres.New() + +// Initialize custom config +store := postgres.New(postgres.Config{ + Host: "127.0.0.1", + Port: 5432, + Database: "fiber", + Table: "fiber_storage", + Reset: false, + GCInterval: 10 * time.Second, + SslMode: "disable", +}) +``` + +### Config +```go +// Config defines the config for storage. +type Config struct { + // Host name where the DB is hosted + // + // Optional. Default is "127.0.0.1" + Host string + + // Port where the DB is listening on + // + // Optional. Default is 5432 + Port int + + // Server username + // + // Optional. Default is "" + Username string + + // Server password + // + // Optional. Default is "" + Password string + + // Database name + // + // Optional. Default is "fiber" + Database string + + // Table name + // + // Optional. Default is "fiber_storage" + Table string + + // Reset clears any existing keys in existing Table + // + // Optional. Default is false + Reset bool + + // Time before deleting expired keys + // + // Optional. Default is 10 * time.Second + GCInterval time.Duration + + // The SSL mode for the connection + // + // Optional. Default is "disable" + SslMode string +} +``` + +### Default Config +```go +var ConfigDefault = Config{ + Host: "127.0.0.1", + Port: 5432, + Database: "fiber", + Table: "fiber_storage", + Reset: false, + GCInterval: 10 * time.Second, + SslMode: "disable", +} +``` diff --git a/postgres/config.go b/postgres/config.go new file mode 100644 index 00000000..02969125 --- /dev/null +++ b/postgres/config.go @@ -0,0 +1,135 @@ +package postgres + +import ( + "time" +) + +// Config defines the config for storage. +type Config struct { + // Host name where the DB is hosted + // + // Optional. Default is "127.0.0.1" + Host string + + // Port where the DB is listening on + // + // Optional. Default is 5432 + Port int + + // Server username + // + // Optional. Default is "" + Username string + + // Server password + // + // Optional. Default is "" + Password string + + // Database name + // + // Optional. Default is "fiber" + Database string + + // Table name + // + // Optional. Default is "fiber_storage" + Table string + + // The SSL mode for the connection + // + // Optional. Default is "disable" + SslMode string + + // Reset clears any existing keys in existing Table + // + // Optional. Default is false + Reset bool + + // Time before deleting expired keys + // + // Optional. Default is 10 * time.Second + GCInterval time.Duration + + //////////////////////////////////// + // Adaptor related config options // + //////////////////////////////////// + + // Maximum wait for connection, in seconds. Zero or + // n < 0 means wait indefinitely. + timeout time.Duration + + // The maximum number of connections in the idle connection pool. + // + // If MaxOpenConns is greater than 0 but less than the new MaxIdleConns, + // then the new MaxIdleConns will be reduced to match the MaxOpenConns limit. + // + // If n <= 0, no idle connections are retained. + // + // The default max idle connections is currently 2. This may change in + // a future release. + maxIdleConns int + + // The maximum number of open connections to the database. + // + // If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than + // MaxIdleConns, then MaxIdleConns will be reduced to match the new + // MaxOpenConns limit. + // + // If n <= 0, then there is no limit on the number of open connections. + // The default is 0 (unlimited). + maxOpenConns int + + // The maximum amount of time a connection may be reused. + // + // Expired connections may be closed lazily before reuse. + // + // If d <= 0, connections are reused forever. + connMaxLifetime time.Duration +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Host: "127.0.0.1", + Port: 5432, + Database: "fiber", + Table: "fiber_storage", + SslMode: "disable", + Reset: false, + GCInterval: 10 * time.Second, + maxOpenConns: 100, + maxIdleConns: 100, + connMaxLifetime: 1 * time.Second, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Host == "" { + cfg.Host = ConfigDefault.Host + } + if cfg.Port <= 0 { + cfg.Port = ConfigDefault.Port + } + if cfg.Database == "" { + cfg.Database = ConfigDefault.Database + } + if cfg.Table == "" { + cfg.Table = ConfigDefault.Table + } + if cfg.SslMode == "" { + cfg.SslMode = ConfigDefault.SslMode + } + if int(cfg.GCInterval.Seconds()) <= 0 { + cfg.GCInterval = ConfigDefault.GCInterval + } + return cfg +} diff --git a/postgres/postgres.go b/postgres/postgres.go new file mode 100644 index 00000000..ac0dc6b9 --- /dev/null +++ b/postgres/postgres.go @@ -0,0 +1,213 @@ +package postgres + +import ( + "database/sql" + "errors" + "fmt" + "net/url" + "strings" + "time" + + _ "github.com/lib/pq" +) + +// Storage interface that is implemented by storage providers +type Storage struct { + db *sql.DB + gcInterval time.Duration + done chan struct{} + + sqlSelect string + sqlInsert string + sqlDelete string + sqlReset string + sqlGC string +} + +var ( + checkSchemaMsg = "The `v` row has an incorrect data type. " + + "It should be BYTEA but is instead %s. This will cause encoding-related panics if the DB is not migrated (see https://github.com/gofiber/storage/blob/main/MIGRATE.md)." + dropQuery = `DROP TABLE IF EXISTS %s;` + initQuery = []string{ + `CREATE TABLE IF NOT EXISTS %s ( + k VARCHAR(64) PRIMARY KEY NOT NULL DEFAULT '', + v BYTEA NOT NULL, + e BIGINT NOT NULL DEFAULT '0' + );`, + `CREATE INDEX IF NOT EXISTS e ON %s (e);`, + } + checkSchemaQuery = `SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS + WHERE table_name = '%s' AND COLUMN_NAME = 'v';` +) + +// New creates a new storage +func New(config ...Config) *Storage { + // Set default config + cfg := configDefault(config...) + + // Create data source name + var dsn string = "postgresql://" + if cfg.Username != "" { + dsn += url.QueryEscape(cfg.Username) + } + if cfg.Password != "" { + dsn += ":" + cfg.Password + } + if cfg.Username != "" || cfg.Password != "" { + dsn += "@" + } + dsn += fmt.Sprintf("%s:%d", url.QueryEscape(cfg.Host), cfg.Port) + dsn += fmt.Sprintf("/%s?connect_timeout=%d&sslmode=%s", + url.QueryEscape(cfg.Database), + int64(cfg.timeout.Seconds()), + cfg.SslMode, + ) + + // Create db + db, err := sql.Open("postgres", dsn) + if err != nil { + panic(err) + } + + // Set database options + db.SetMaxOpenConns(cfg.maxOpenConns) + db.SetMaxIdleConns(cfg.maxIdleConns) + db.SetConnMaxLifetime(cfg.connMaxLifetime) + + // Ping database + if err := db.Ping(); err != nil { + panic(err) + } + + // Drop table if set to true + if cfg.Reset { + if _, err = db.Exec(fmt.Sprintf(dropQuery, cfg.Table)); err != nil { + _ = db.Close() + panic(err) + } + } + + // Init database queries + for _, query := range initQuery { + if _, err := db.Exec(fmt.Sprintf(query, cfg.Table)); err != nil { + _ = db.Close() + + panic(err) + } + } + + // Create storage + store := &Storage{ + db: db, + gcInterval: cfg.GCInterval, + done: make(chan struct{}), + sqlSelect: fmt.Sprintf(`SELECT v, e FROM %s WHERE k=$1;`, cfg.Table), + sqlInsert: fmt.Sprintf("INSERT INTO %s (k, v, e) VALUES ($1, $2, $3) ON CONFLICT (k) DO UPDATE SET v = $4, e = $5", cfg.Table), + sqlDelete: fmt.Sprintf("DELETE FROM %s WHERE k=$1", cfg.Table), + sqlReset: fmt.Sprintf("TRUNCATE TABLE %s;", cfg.Table), + sqlGC: fmt.Sprintf("DELETE FROM %s WHERE e <= $1 AND e != 0", cfg.Table), + } + + store.checkSchema(cfg.Table) + + // Start garbage collector + go store.gcTicker() + + return store +} + +var noRows = errors.New("sql: no rows in result set") + +// Get value by key +func (s *Storage) Get(key string) ([]byte, error) { + if len(key) <= 0 { + return nil, nil + } + row := s.db.QueryRow(s.sqlSelect, key) + // Add db response to data + var ( + data = []byte{} + exp int64 = 0 + ) + if err := row.Scan(&data, &exp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + + // If the expiration time has already passed, then return nil + if exp != 0 && exp <= time.Now().Unix() { + return nil, nil + } + + return data, nil +} + +// Set key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + // Ain't Nobody Got Time For That + if len(key) <= 0 || len(val) <= 0 { + return nil + } + var expSeconds int64 + if exp != 0 { + expSeconds = time.Now().Add(exp).Unix() + } + _, err := s.db.Exec(s.sqlInsert, key, val, expSeconds, val, expSeconds) + return err +} + +// Delete entry by key +func (s *Storage) Delete(key string) error { + // Ain't Nobody Got Time For That + if len(key) <= 0 { + return nil + } + _, err := s.db.Exec(s.sqlDelete, key) + return err +} + +// Reset all entries, including unexpired +func (s *Storage) Reset() error { + _, err := s.db.Exec(s.sqlReset) + return err +} + +// Close the database +func (s *Storage) Close() error { + s.done <- struct{}{} + return s.db.Close() +} + +// gcTicker starts the gc ticker +func (s *Storage) gcTicker() { + ticker := time.NewTicker(s.gcInterval) + defer ticker.Stop() + for { + select { + case <-s.done: + return + case t := <-ticker.C: + s.gc(t) + } + } +} + +// gc deletes all expired entries +func (s *Storage) gc(t time.Time) { + _, _ = s.db.Exec(s.sqlGC, t.Unix()) +} + +func (s *Storage) checkSchema(tableName string) { + var data []byte + + row := s.db.QueryRow(fmt.Sprintf(checkSchemaQuery, tableName)) + if err := row.Scan(&data); err != nil { + panic(err) + } + + if strings.ToLower(string(data)) != "bytea" { + fmt.Printf(checkSchemaMsg, string(data)) + } +} diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go new file mode 100644 index 00000000..062b235f --- /dev/null +++ b/postgres/postgres_test.go @@ -0,0 +1,179 @@ +package postgres + +import ( + "database/sql" + "os" + "testing" + "time" + + "github.com/gofiber/utils" +) + +var testStore = New(Config{ + Database: os.Getenv("POSTGRES_DATABASE"), + Username: os.Getenv("POSTGRES_USERNAME"), + Password: os.Getenv("POSTGRES_PASSWORD"), + Reset: true, +}) + +func Test_Postgres_Set(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + err := testStore.Set(key, val, 0) + utils.AssertEqual(t, nil, err) +} + +func Test_Postgres_Set_Override(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + err := testStore.Set(key, val, 0) + utils.AssertEqual(t, nil, err) + + err = testStore.Set(key, val, 0) + utils.AssertEqual(t, nil, err) +} + +func Test_Postgres_Get(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + err := testStore.Set(key, val, 0) + utils.AssertEqual(t, nil, err) + + result, err := testStore.Get(key) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, val, result) +} + +func Test_Postgres_Set_Expiration(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + exp = 1 * time.Second + ) + + err := testStore.Set(key, val, exp) + utils.AssertEqual(t, nil, err) + + time.Sleep(1100 * time.Millisecond) +} + +func Test_Postgres_Get_Expired(t *testing.T) { + var ( + key = "john" + ) + + result, err := testStore.Get(key) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, len(result) == 0) +} + +func Test_Postgres_Get_NotExist(t *testing.T) { + + result, err := testStore.Get("notexist") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, len(result) == 0) +} + +func Test_Postgres_Delete(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + err := testStore.Set(key, val, 0) + utils.AssertEqual(t, nil, err) + + err = testStore.Delete(key) + utils.AssertEqual(t, nil, err) + + result, err := testStore.Get(key) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, len(result) == 0) +} + +func Test_Postgres_Reset(t *testing.T) { + var ( + val = []byte("doe") + ) + + err := testStore.Set("john1", val, 0) + utils.AssertEqual(t, nil, err) + + err = testStore.Set("john2", val, 0) + utils.AssertEqual(t, nil, err) + + err = testStore.Reset() + utils.AssertEqual(t, nil, err) + + result, err := testStore.Get("john1") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, len(result) == 0) + + result, err = testStore.Get("john2") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, len(result) == 0) +} + +func Test_Postgres_GC(t *testing.T) { + var ( + testVal = []byte("doe") + ) + + // This key should expire + err := testStore.Set("john", testVal, time.Nanosecond) + utils.AssertEqual(t, nil, err) + + testStore.gc(time.Now()) + row := testStore.db.QueryRow(testStore.sqlSelect, "john") + err = row.Scan(nil, nil) + utils.AssertEqual(t, sql.ErrNoRows, err) + + // This key should not expire + err = testStore.Set("john", testVal, 0) + utils.AssertEqual(t, nil, err) + + testStore.gc(time.Now()) + val, err := testStore.Get("john") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, testVal, val) + +} + +func Test_Postgres_Non_UTF8(t *testing.T) { + val := []byte("0xF5") + + err := testStore.Set("0xF6", val, 0) + utils.AssertEqual(t, nil, err) + + result, err := testStore.Get("0xF6") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, val, result) +} + +func Test_SslRequiredMode(t *testing.T) { + defer func() { + if recover() == nil { + utils.AssertEqual(t, true, nil, "Connection was established with a `require`") + } + }() + _ = New(Config{ + Database: "fiber", + Username: "username", + Password: "password", + Reset: true, + SslMode: "require", + }) +} + +func Test_Postgres_Close(t *testing.T) { + utils.AssertEqual(t, nil, testStore.Close()) +}