From 508747c957d2af74c867d450030ab31d036e6580 Mon Sep 17 00:00:00 2001 From: Logan Date: Thu, 19 Aug 2021 00:02:34 +0000 Subject: [PATCH] Persist pragma configurations via url parameter --- AUTHORS | 1 + CONTRIBUTORS | 1 + all_test.go | 163 +++++++++++++++++++++++++++++++++++++++++++++++++++ sqlite.go | 38 +++++++++++- 4 files changed, 201 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index eee21ef..22ed356 100644 --- a/AUTHORS +++ b/AUTHORS @@ -12,6 +12,7 @@ Dan Peterson Davsk Ltd Co Jaap Aarts Jan Mercl <0xjnml@gmail.com> +Logan Snow Ross Light Steffen Butzer Saed SayedAhmed diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 61a74c3..af689e6 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -11,6 +11,7 @@ Dan Peterson David Skinner Jaap Aarts Jan Mercl <0xjnml@gmail.com> +Logan Snow Ross Light Steffen Butzer Yaacov Akiba Slama diff --git a/all_test.go b/all_test.go index 686f012..09169f0 100644 --- a/all_test.go +++ b/all_test.go @@ -8,10 +8,12 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" "flag" "fmt" "io/ioutil" "math/rand" + "net/url" "os" "os/exec" "path" @@ -1574,6 +1576,167 @@ CREATE TABLE IF NOT EXISTS loginst ( } +// https://gitlab.com/cznic/sqlite/-/issues/37 +func TestPersistPragma(t *testing.T) { + if err := emptyDir(tempDir); err != nil { + t.Fatal(err) + } + + wd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + + defer os.Chdir(wd) + + if err := os.Chdir(tempDir); err != nil { + t.Fatal(err) + } + + pragmas := []pragmaCfg{ + {"foreign_keys", "on", int64(1)}, + {"analysis_limit", "1000", int64(1000)}, + {"application_id", "214", int64(214)}, + {"encoding", "'UTF-16le'", "UTF-16le"}} + + if err := testPragmas("x.sqlite", "x.sqlite", pragmas); err != nil { + t.Fatal(err) + } + if err := testPragmas("file::memory:", "", pragmas); err != nil { + t.Fatal(err) + } + if err := testPragmas(":memory:", "", pragmas); err != nil { + t.Fatal(err) + } +} + +type pragmaCfg struct { + name string + value string + expected interface{} +} + +func testPragmas(name, diskFile string, pragmas []pragmaCfg) error { + if diskFile != "" { + os.Remove(diskFile) + } + + q := url.Values{} + for _, pragma := range pragmas { + q.Add("_pragma", pragma.name+"="+pragma.value) + } + + dsn := name + "?" + q.Encode() + db, err := sql.Open(driverName, dsn) + if err != nil { + return err + } + + db.SetMaxOpenConns(1) + + if err := checkPragmas(db, pragmas); err != nil { + return err + } + + c, err := db.Conn(context.Background()) + if err != nil { + return err + } + + // Kill the connection to spawn a new one. Pragma configs should persist + c.Raw(func(interface{}) error { return driver.ErrBadConn }) + + if err := checkPragmas(db, pragmas); err != nil { + return err + } + + if diskFile == "" { + // Make sure in memory databases aren't being written to disk + return testInMemory(db) + } + + return nil +} + +func checkPragmas(db *sql.DB, pragmas []pragmaCfg) error { + for _, pragma := range pragmas { + row := db.QueryRow(`PRAGMA ` + pragma.name) + + var result interface{} + if err := row.Scan(&result); err != nil { + return err + } + if result != pragma.expected { + return fmt.Errorf("expected PRAGMA %s to return %v but got %v", pragma.name, pragma.expected, result) + } + } + return nil +} + +func TestInMemory(t *testing.T) { + if err := emptyDir(tempDir); err != nil { + t.Fatal(err) + } + + wd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + + defer os.Chdir(wd) + + if err := os.Chdir(tempDir); err != nil { + t.Fatal(err) + } + + if err := testMemoryPath(":memory:"); err != nil { + t.Fatal(err) + } + if err := testMemoryPath("file::memory:"); err != nil { + t.Fatal(err) + } + + // This parameter should be ignored + q := url.Values{} + q.Add("mode", "readonly") + if err := testMemoryPath(":memory:?" + q.Encode()); err != nil { + t.Fatal(err) + } +} + +func testMemoryPath(mPath string) error { + db, err := sql.Open(driverName, mPath) + if err != nil { + return err + } + defer db.Close() + + return testInMemory(db) +} + +func testInMemory(db *sql.DB) error { + _, err := db.Exec(` + create table in_memory_test(i int, f double); + insert into in_memory_test values(12, 3.14); + `) + if err != nil { + return err + } + + files, err := ioutil.ReadDir("./") + if err != nil { + return err + } + + for _, file := range files { + if strings.Contains(file.Name(), "memory") { + return fmt.Errorf("file was created for in memory database") + } + } + + return nil +} + func emptyDir(s string) error { m, err := filepath.Glob(filepath.FromSlash(s + "/*")) if err != nil { diff --git a/sqlite.go b/sqlite.go index 30f6714..ec3873e 100644 --- a/sqlite.go +++ b/sqlite.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "math" + "net/url" "reflect" "strconv" "strings" @@ -727,10 +728,23 @@ type conn struct { sync.Mutex } -func newConn(name string) (*conn, error) { +func newConn(dsn string) (*conn, error) { + var query string + + // Parse the query parameters from the dsn and them from the dsn if not prefixed by file: + // https://github.com/mattn/go-sqlite3/blob/3392062c729d77820afc1f5cae3427f0de39e954/sqlite3.go#L1046 + // https://github.com/mattn/go-sqlite3/blob/3392062c729d77820afc1f5cae3427f0de39e954/sqlite3.go#L1383 + pos := strings.IndexRune(dsn, '?') + if pos >= 1 { + query = dsn[pos+1:] + if !strings.HasPrefix(dsn, "file:") { + dsn = dsn[:pos] + } + } + c := &conn{tls: libc.NewTLS()} db, err := c.openV2( - name, + dsn, sqlite3.SQLITE_OPEN_READWRITE|sqlite3.SQLITE_OPEN_CREATE| sqlite3.SQLITE_OPEN_FULLMUTEX| sqlite3.SQLITE_OPEN_URI, @@ -745,9 +759,29 @@ func newConn(name string) (*conn, error) { return nil, err } + if err = applyPragmas(c, query); err != nil { + c.Close() + return nil, err + } + return c, nil } +func applyPragmas(c *conn, query string) error { + q, err := url.ParseQuery(query) + if err != nil { + return err + } + for _, v := range q["_pragma"] { + cmd := "pragma " + v + _, err := c.exec(context.Background(), cmd, nil) + if err != nil { + return err + } + } + return nil +} + // const void *sqlite3_column_blob(sqlite3_stmt*, int iCol); func (c *conn) columnBlob(pstmt uintptr, iCol int) (v []byte, err error) { p := sqlite3.Xsqlite3_column_blob(c.tls, pstmt, int32(iCol))