diff --git a/.gitignore b/.gitignore index 78e20aa2..e49c0051 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ *.fasthttp.gz *.pprof *.workspace +/tmp/ # Dependencies /vendor/ diff --git a/postgres/postgres.go b/postgres/postgres.go index e995e954..3250b187 100644 --- a/postgres/postgres.go +++ b/postgres/postgres.go @@ -26,19 +26,31 @@ type Storage struct { } var ( - checkSchemaMsg = "The `v` row has an incorrect data type. " + - "It should be BYTEA but is instead %s. This will cause encoding-related panics if the DB is not migrated (see https://github.com/gofiber/storage/blob/main/MIGRATE.md)." - dropQuery = `DROP TABLE IF EXISTS %s;` + checkSchemaMsg = "The `%s` row has an incorrect data type. " + + "It should be %s but is instead %s. This will cause encoding-related panics if the DB is not migrated (see https://github.com/gofiber/storage/blob/main/MIGRATE.md)." + dropQuery = `DROP TABLE IF EXISTS %s;` + checkTableExistsQuery = `SELECT COUNT(table_name) + FROM information_schema.tables + WHERE table_schema = '%s' + AND table_name = '%s';` initQuery = []string{ - `CREATE TABLE IF NOT EXISTS %s ( + `CREATE TABLE %s ( k VARCHAR(64) PRIMARY KEY NOT NULL DEFAULT '', v BYTEA NOT NULL, e BIGINT NOT NULL DEFAULT '0' );`, `CREATE INDEX IF NOT EXISTS e ON %s (e);`, } - checkSchemaQuery = `SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS - WHERE table_name = '%s' AND COLUMN_NAME = 'v';` + checkSchemaQuery = `SELECT column_name, data_type + FROM information_schema.columns + WHERE table_schema = '%s' + AND table_name = '%s' + AND column_name IN ('k','v','e');` + checkSchemaTargetDataType = map[string]string{ + "k": "character varying", + "v": "bytea", + "e": "bigint", + } ) // New creates a new storage @@ -61,6 +73,14 @@ func New(config ...Config) *Storage { panic(err) } + // Parse out schema in config, if provided + schema := "public" + tableName := cfg.Table + if strings.Contains(cfg.Table, ".") { + schema = strings.Split(cfg.Table, ".")[0] + tableName = strings.Split(cfg.Table, ".")[1] + } + // Drop table if set to true if cfg.Reset { if _, err := db.Exec(context.Background(), fmt.Sprintf(dropQuery, cfg.Table)); err != nil { @@ -69,11 +89,23 @@ func New(config ...Config) *Storage { } } + // Determine if table exists + tableExists := false + row := db.QueryRow(context.Background(), fmt.Sprintf(checkTableExistsQuery, schema, tableName)) + var count int + if err := row.Scan(&count); err != nil { + db.Close() + panic(err) + } + tableExists = count > 0 + // Init database queries - for _, query := range initQuery { - if _, err := db.Exec(context.Background(), fmt.Sprintf(query, cfg.Table)); err != nil { - db.Close() - panic(err) + if !tableExists { + for _, query := range initQuery { + if _, err := db.Exec(context.Background(), fmt.Sprintf(query, cfg.Table)); err != nil { + db.Close() + panic(err) + } } } @@ -185,15 +217,41 @@ func (s *Storage) gc(t time.Time) { _, _ = s.db.Exec(context.Background(), s.sqlGC, t.Unix()) } -func (s *Storage) checkSchema(tableName string) { - var data []byte +func (s *Storage) checkSchema(fullTableName string) { + schema := "public" + tableName := fullTableName + if strings.Contains(fullTableName, ".") { + schema = strings.Split(fullTableName, ".")[0] + tableName = strings.Split(fullTableName, ".")[1] + } - row := s.db.QueryRow(context.Background(), fmt.Sprintf(checkSchemaQuery, tableName)) - if err := row.Scan(&data); err != nil { + rows, err := s.db.Query(context.Background(), fmt.Sprintf(checkSchemaQuery, schema, tableName)) + if err != nil { panic(err) } + defer rows.Close() - if strings.ToLower(string(data)) != "bytea" { - fmt.Printf(checkSchemaMsg, string(data)) + data := make(map[string]string) + + rowCount := 0 + for rows.Next() { + var columnName, dataType string + if err := rows.Scan(&columnName, &dataType); err != nil { + panic(err) + } + data[columnName] = dataType + rowCount++ + } + if rowCount == 0 { + panic(fmt.Errorf("table %s does not exist", tableName)) + } + for columnName, dataType := range checkSchemaTargetDataType { + dt, ok := data[columnName] + if !ok { + panic(fmt.Errorf("required column %s does not exist in table %s", columnName, tableName)) + } + if dt != dataType { + panic(fmt.Errorf(checkSchemaMsg, columnName, dataType, dt)) + } } } diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go index 0a226522..ad02c0b8 100644 --- a/postgres/postgres_test.go +++ b/postgres/postgres_test.go @@ -2,7 +2,9 @@ package postgres import ( "context" + "math/rand" "os" + "strconv" "testing" "time" @@ -17,6 +19,157 @@ var testStore = New(Config{ Reset: true, }) +func TestNoCreateUser(t *testing.T) { + // Create a new user + // give the use usage permissions to the database (but not create) + ctx := context.Background() + conn := testStore.Conn() + + username := "testuser" + strconv.Itoa(rand.Intn(1_000_000)) + password := "testpassword" + + _, err := conn.Exec(ctx, "CREATE USER "+username+" WITH PASSWORD '"+password+"'") + require.NoError(t, err) + + _, err = conn.Exec(ctx, "GRANT CONNECT ON DATABASE "+os.Getenv("POSTGRES_DATABASE")+" TO "+username) + require.NoError(t, err) + + _, err = conn.Exec(ctx, "GRANT USAGE ON SCHEMA public TO "+username) + require.NoError(t, err) + + _, err = conn.Exec(ctx, "GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO "+username) + require.NoError(t, err) + + _, err = conn.Exec(ctx, "ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO "+username) + require.NoError(t, err) + + _, err = conn.Exec(ctx, "REVOKE CREATE ON SCHEMA public FROM "+username) + require.NoError(t, err) + + t.Run("should panic if limited user tries to create table", func(t *testing.T) { + tableThatDoesNotExist := "public.table_does_not_exists_" + strconv.Itoa(rand.Intn(1_000_000)) + + defer func() { + r := recover() + require.NotNil(t, r, "Expected a panic when creating a table without permissions") + }() + + // This should panic since the user doesn't have CREATE permissions + New(Config{ + Database: os.Getenv("POSTGRES_DATABASE"), + Username: username, + Password: password, + Reset: true, + Table: tableThatDoesNotExist, + }) + }) + + // connect to an existing table using an unprivileged user + limitedStore := New(Config{ + Database: os.Getenv("POSTGRES_DATABASE"), + Username: username, + Password: password, + Reset: false, + }) + + defer func() { + limitedStore.Close() + conn.Exec(ctx, "DROP USER "+username) + }() + + t.Run("shoud set", func(t *testing.T) { + var ( + key = "john" + strconv.Itoa(rand.Intn(1_000_000)) + val = []byte("doe" + strconv.Itoa(rand.Intn(1_000_000))) + ) + + err := limitedStore.Set(key, val, 0) + require.NoError(t, err) + }) + t.Run("should set override", func(t *testing.T) { + var ( + key = "john" + strconv.Itoa(rand.Intn(1_000_000)) + val = []byte("doe" + strconv.Itoa(rand.Intn(1_000_000))) + ) + err := limitedStore.Set(key, val, 0) + require.NoError(t, err) + err = limitedStore.Set(key, val, 0) + require.NoError(t, err) + }) + t.Run("should get", func(t *testing.T) { + var ( + key = "john" + strconv.Itoa(rand.Intn(1_000_000)) + val = []byte("doe" + strconv.Itoa(rand.Intn(1_000_000))) + ) + err := limitedStore.Set(key, val, 0) + require.NoError(t, err) + result, err := limitedStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) + }) + t.Run("should set expiration", func(t *testing.T) { + var ( + key = "john" + strconv.Itoa(rand.Intn(1_000_000)) + val = []byte("doe" + strconv.Itoa(rand.Intn(1_000_000))) + exp = 100 * time.Millisecond + ) + err := limitedStore.Set(key, val, exp) + require.NoError(t, err) + }) + t.Run("should get expired", func(t *testing.T) { + var ( + key = "john" + strconv.Itoa(rand.Intn(1_000_000)) + val = []byte("doe" + strconv.Itoa(rand.Intn(1_000_000))) + exp = 100 * time.Millisecond + ) + err := limitedStore.Set(key, val, exp) + require.NoError(t, err) + time.Sleep(200 * time.Millisecond) + result, err := limitedStore.Get(key) + require.NoError(t, err) + require.Zero(t, len(result)) + }) + t.Run("should get not exists", func(t *testing.T) { + result, err := limitedStore.Get("nonexistentkey") + require.NoError(t, err) + require.Zero(t, len(result)) + }) + t.Run("should delete", func(t *testing.T) { + var ( + key = "john" + strconv.Itoa(rand.Intn(1_000_000)) + val = []byte("doe" + strconv.Itoa(rand.Intn(1_000_000))) + ) + err := limitedStore.Set(key, val, 0) + require.NoError(t, err) + err = limitedStore.Delete(key) + require.NoError(t, err) + result, err := limitedStore.Get(key) + require.NoError(t, err) + require.Zero(t, len(result)) + }) + +} +func Test_Should_Panic_On_Wrong_Schema(t *testing.T) { + // Create a test table with wrong schema + _, err := testStore.Conn().Exec(context.Background(), ` + CREATE TABLE IF NOT EXISTS test_schema_table ( + k VARCHAR(64) PRIMARY KEY NOT NULL DEFAULT '', + v BYTEA NOT NULL, + e VARCHAR(64) NOT NULL DEFAULT '' -- Changed e from BIGINT to VARCHAR + ); + `) + require.NoError(t, err) + defer func() { + _, err := testStore.Conn().Exec(context.Background(), "DROP TABLE IF EXISTS test_schema_table;") + require.NoError(t, err) + }() + + // Call checkSchema with the wrong table + require.Panics(t, func() { + testStore.checkSchema("test_schema_table") + }) +} + func Test_Postgres_Set(t *testing.T) { var ( key = "john"