This commit is contained in:
Muhammed Efe Çetin
2023-03-12 13:17:11 +03:00
parent 5627741945
commit 57f18bb17e
2 changed files with 36 additions and 3 deletions

View File

@@ -2,6 +2,8 @@ package postgres
import ( import (
"fmt" "fmt"
"net/url"
"strings"
"time" "time"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
@@ -49,6 +51,11 @@ type Config struct {
// Optional. Default is "fiber_storage" // Optional. Default is "fiber_storage"
Table string Table string
// The SSL mode for the connection
//
// Optional. Default is "disable"
SSLMode string
// Reset clears any existing keys in existing Table // Reset clears any existing keys in existing Table
// //
// Optional. Default is false // Optional. Default is false
@@ -67,15 +74,41 @@ var ConfigDefault = Config{
Port: 5432, Port: 5432,
Database: "fiber", Database: "fiber",
Table: "fiber_storage", Table: "fiber_storage",
SSLMode: "disable",
Reset: false, Reset: false,
GCInterval: 10 * time.Second, GCInterval: 10 * time.Second,
} }
func (c Config) dsn() string { func (c *Config) getDSN() string {
// Just return ConnectionURI if it's already exists
if c.ConnectionURI != "" { if c.ConnectionURI != "" {
return c.ConnectionURI return c.ConnectionURI
} }
return fmt.Sprintf("postgres://%s:%s@%s:%d/%s", c.Username, c.Password, c.Host, c.Port, c.Database)
// Generate DSN
dsn := "postgresql://"
if c.Username != "" {
dsn += url.QueryEscape(c.Username)
}
if c.Password != "" {
dsn += ":" + url.QueryEscape(c.Password)
}
if c.Username != "" || c.Password != "" {
dsn += "@"
}
// unix socket host path
if strings.HasPrefix(c.Host, "/") {
dsn += fmt.Sprintf("%s:%d", c.Host, c.Port)
} else {
dsn += fmt.Sprintf("%s:%d", url.QueryEscape(c.Host), c.Port)
}
dsn += fmt.Sprintf("/%ssslmode=%s",
url.QueryEscape(c.Database),
c.SSLMode)
return dsn
} }
// Helper function to set default values // Helper function to set default values

View File

@@ -50,7 +50,7 @@ func New(config ...Config) *Storage {
var err error var err error
db := cfg.DB db := cfg.DB
if db == nil { if db == nil {
db, err = pgxpool.New(context.Background(), cfg.dsn()) db, err = pgxpool.New(context.Background(), cfg.getDSN())
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Unable to create connection pool: %v\n", err) fmt.Fprintf(os.Stderr, "Unable to create connection pool: %v\n", err)
} }