diff --git a/postgres/config.go b/postgres/config.go index 64ed4269..ebbcfa24 100644 --- a/postgres/config.go +++ b/postgres/config.go @@ -2,6 +2,8 @@ package postgres import ( "fmt" + "net/url" + "strings" "time" "github.com/jackc/pgx/v5/pgxpool" @@ -49,6 +51,11 @@ type Config struct { // Optional. Default is "fiber_storage" Table string + // The SSL mode for the connection + // + // Optional. Default is "disable" + SSLMode string + // Reset clears any existing keys in existing Table // // Optional. Default is false @@ -67,15 +74,41 @@ var ConfigDefault = Config{ Port: 5432, Database: "fiber", Table: "fiber_storage", + SSLMode: "disable", Reset: false, GCInterval: 10 * time.Second, } -func (c Config) dsn() string { +func (c *Config) getDSN() string { + // Just return ConnectionURI if it's already exists if c.ConnectionURI != "" { return c.ConnectionURI } - return fmt.Sprintf("postgres://%s:%s@%s:%d/%s", c.Username, c.Password, c.Host, c.Port, c.Database) + + // Generate DSN + dsn := "postgresql://" + if c.Username != "" { + dsn += url.QueryEscape(c.Username) + } + if c.Password != "" { + dsn += ":" + url.QueryEscape(c.Password) + } + if c.Username != "" || c.Password != "" { + dsn += "@" + } + + // unix socket host path + if strings.HasPrefix(c.Host, "/") { + dsn += fmt.Sprintf("%s:%d", c.Host, c.Port) + } else { + dsn += fmt.Sprintf("%s:%d", url.QueryEscape(c.Host), c.Port) + } + + dsn += fmt.Sprintf("/%ssslmode=%s", + url.QueryEscape(c.Database), + c.SSLMode) + + return dsn } // Helper function to set default values diff --git a/postgres/postgres.go b/postgres/postgres.go index 348dfefd..e995e954 100644 --- a/postgres/postgres.go +++ b/postgres/postgres.go @@ -50,7 +50,7 @@ func New(config ...Config) *Storage { var err error db := cfg.DB if db == nil { - db, err = pgxpool.New(context.Background(), cfg.dsn()) + db, err = pgxpool.New(context.Background(), cfg.getDSN()) if err != nil { fmt.Fprintf(os.Stderr, "Unable to create connection pool: %v\n", err) }