diff --git a/postgres/README.md b/postgres/README.md index b5c2d433..feada1db 100644 --- a/postgres/README.md +++ b/postgres/README.md @@ -47,6 +47,7 @@ store := postgres.New(postgres.Config{ Table: "fiber_storage", Reset: false, GCInterval: 10 * time.Second, + SslMode: "disable", }) ``` @@ -93,6 +94,11 @@ type Config struct { // // Optional. Default is 10 * time.Second GCInterval time.Duration + + // The SSL mode for the connection + // + // Optional. Default is "disable" + SslMode string } ``` @@ -105,5 +111,6 @@ var ConfigDefault = Config{ Table: "fiber_storage", Reset: false, GCInterval: 10 * time.Second, + SslMode: "disable", } ``` diff --git a/postgres/config.go b/postgres/config.go index 0b6a9c23..02969125 100644 --- a/postgres/config.go +++ b/postgres/config.go @@ -36,6 +36,11 @@ type Config struct { // Optional. Default is "fiber_storage" Table string + // The SSL mode for the connection + // + // Optional. Default is "disable" + SslMode string + // Reset clears any existing keys in existing Table // // Optional. Default is false @@ -89,6 +94,7 @@ var ConfigDefault = Config{ Port: 5432, Database: "fiber", Table: "fiber_storage", + SslMode: "disable", Reset: false, GCInterval: 10 * time.Second, maxOpenConns: 100, @@ -119,6 +125,9 @@ func configDefault(config ...Config) Config { if cfg.Table == "" { cfg.Table = ConfigDefault.Table } + if cfg.SslMode == "" { + cfg.SslMode = ConfigDefault.SslMode + } if int(cfg.GCInterval.Seconds()) <= 0 { cfg.GCInterval = ConfigDefault.GCInterval } diff --git a/postgres/postgres.go b/postgres/postgres.go index 3e5797da..ac0dc6b9 100644 --- a/postgres/postgres.go +++ b/postgres/postgres.go @@ -57,9 +57,10 @@ func New(config ...Config) *Storage { dsn += "@" } dsn += fmt.Sprintf("%s:%d", url.QueryEscape(cfg.Host), cfg.Port) - dsn += fmt.Sprintf("/%s?connect_timeout=%d&sslmode=disable", + dsn += fmt.Sprintf("/%s?connect_timeout=%d&sslmode=%s", url.QueryEscape(cfg.Database), int64(cfg.timeout.Seconds()), + cfg.SslMode, ) // Create db diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go index c8880f52..062b235f 100644 --- a/postgres/postgres_test.go +++ b/postgres/postgres_test.go @@ -159,6 +159,21 @@ func Test_Postgres_Non_UTF8(t *testing.T) { utils.AssertEqual(t, val, result) } +func Test_SslRequiredMode(t *testing.T) { + defer func() { + if recover() == nil { + utils.AssertEqual(t, true, nil, "Connection was established with a `require`") + } + }() + _ = New(Config{ + Database: "fiber", + Username: "username", + Password: "password", + Reset: true, + SslMode: "require", + }) +} + func Test_Postgres_Close(t *testing.T) { utils.AssertEqual(t, nil, testStore.Close()) }