From 2097e63911c0123d8fee3059a8e76ad2389f8a27 Mon Sep 17 00:00:00 2001 From: glebarez Date: Tue, 5 Apr 2022 18:04:19 +0300 Subject: [PATCH] merge upstream --- AUTHORS | 1 + CONTRIBUTORS | 3 +- all_test.go | 265 ++++++++++++++++++++++++++++++++++++++ benchmark/README.md | 2 +- benchmark/go.sum | 12 +- functest/func_test.go | 292 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 5 +- go.sum | 11 +- sqlite.go | 226 +++++++++++++++++++++++++++++++- 9 files changed, 797 insertions(+), 20 deletions(-) create mode 100644 functest/func_test.go diff --git a/AUTHORS b/AUTHORS index 22ed356..86d8640 100644 --- a/AUTHORS +++ b/AUTHORS @@ -13,6 +13,7 @@ Davsk Ltd Co Jaap Aarts Jan Mercl <0xjnml@gmail.com> Logan Snow +Michael Hoffmann Ross Light Steffen Butzer Saed SayedAhmed diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 6e4d262..036da46 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -14,7 +14,8 @@ Jaap Aarts Jan Mercl <0xjnml@gmail.com> Logan Snow Matthew Gabeler-Lee +Michael Hoffmann Ross Light Steffen Butzer Yaacov Akiba Slama -Saed SayedAhmed \ No newline at end of file +Saed SayedAhmed diff --git a/all_test.go b/all_test.go index 0a34668..4ed8ef2 100644 --- a/all_test.go +++ b/all_test.go @@ -9,6 +9,7 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "flag" "fmt" "io/ioutil" @@ -19,6 +20,7 @@ import ( "path" "path/filepath" "reflect" + "regexp" "runtime" "runtime/debug" "strconv" @@ -26,6 +28,7 @@ import ( "sync" "testing" "time" + "unsafe" "modernc.org/libc" "modernc.org/mathutil" @@ -146,6 +149,26 @@ func tempDB(t testing.TB) (string, *sql.DB) { return dir, db } +// https://gitlab.com/cznic/sqlite/issues/98 +func TestIssue98(t *testing.T) { + dir, db := tempDB(t) + + defer func() { + db.Close() + os.RemoveAll(dir) + }() + + if _, err := db.Exec("create table t(b mediumblob not null)"); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("insert into t values (?)", []byte{}); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("insert into t values (?)", nil); err == nil { + t.Fatal(errors.New("expected statement to fail")) + } +} + func TestScalar(t *testing.T) { dir, db := tempDB(t) @@ -214,6 +237,243 @@ func TestScalar(t *testing.T) { } } +func TestRedefineUserDefinedFunction(t *testing.T) { + dir, db := tempDB(t) + ctx := context.Background() + + defer func() { + db.Close() + os.RemoveAll(dir) + }() + + connection, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + + var r int + funName := "test" + + if err = connection.Raw(func(driverConn interface{}) error { + c := driverConn.(*conn) + + name, err := libc.CString(funName) + if err != nil { + return err + } + + return c.createFunctionInternal(&userDefinedFunction{ + zFuncName: name, + nArg: 0, + eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC, + xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { + sqlite3.Xsqlite3_result_int(tls, ctx, 1) + }, + }) + }); err != nil { + t.Fatal(err) + } + row := connection.QueryRowContext(ctx, "select test()") + + if err := row.Scan(&r); err != nil { + t.Fatal(err) + } + + if g, e := r, 1; g != e { + t.Fatal(g, e) + } + + if err = connection.Raw(func(driverConn interface{}) error { + c := driverConn.(*conn) + + name, err := libc.CString(funName) + if err != nil { + return err + } + + return c.createFunctionInternal(&userDefinedFunction{ + zFuncName: name, + nArg: 0, + eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC, + xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { + sqlite3.Xsqlite3_result_int(tls, ctx, 2) + }, + }) + }); err != nil { + t.Fatal(err) + } + row = connection.QueryRowContext(ctx, "select test()") + + if err := row.Scan(&r); err != nil { + t.Fatal(err) + } + + if g, e := r, 2; g != e { + t.Fatal(g, e) + } +} + +func TestRegexpUserDefinedFunction(t *testing.T) { + dir, db := tempDB(t) + ctx := context.Background() + + defer func() { + db.Close() + os.RemoveAll(dir) + }() + + connection, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + + if err = connection.Raw(func(driverConn interface{}) error { + c := driverConn.(*conn) + + name, err := libc.CString("regexp") + if err != nil { + return err + } + + return c.createFunctionInternal(&userDefinedFunction{ + zFuncName: name, + nArg: 2, + eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC, + xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { + const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{}) + + argvv := make([]uintptr, argc) + for i := int32(0); i < argc; i++ { + argvv[i] = *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize)) + } + + setErrorResult := func(res error) { + errmsg, cerr := libc.CString(res.Error()) + if cerr != nil { + panic(cerr) + } + defer libc.Xfree(tls, errmsg) + sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1) + sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR) + } + + var s1 string + switch sqlite3.Xsqlite3_value_type(tls, argvv[0]) { + case sqlite3.SQLITE_TEXT: + s1 = libc.GoString(sqlite3.Xsqlite3_value_text(tls, argvv[0])) + default: + setErrorResult(errors.New("expected argv[0] to be text")) + return + } + + var s2 string + switch sqlite3.Xsqlite3_value_type(tls, argvv[1]) { + case sqlite3.SQLITE_TEXT: + s2 = libc.GoString(sqlite3.Xsqlite3_value_text(tls, argvv[1])) + default: + setErrorResult(errors.New("expected argv[1] to be text")) + return + } + + matched, err := regexp.MatchString(s1, s2) + if err != nil { + setErrorResult(fmt.Errorf("bad regular expression: %q", err)) + return + } + sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(matched)) + }, + }) + }); err != nil { + t.Fatal(err) + } + + t.Run("regexp filter", func(tt *testing.T) { + t1 := "seafood" + t2 := "fruit" + + connection.ExecContext(ctx, ` +create table t(b text); +insert into t values(?), (?); +`, t1, t2) + + rows, err := connection.QueryContext(ctx, "select * from t where b regexp 'foo.*'") + if err != nil { + tt.Fatal(err) + } + + type rec struct { + b string + } + var a []rec + for rows.Next() { + var r rec + if err := rows.Scan(&r.b); err != nil { + tt.Fatal(err) + } + + a = append(a, r) + } + if err := rows.Err(); err != nil { + tt.Fatal(err) + } + + if g, e := len(a), 1; g != e { + tt.Fatal(g, e) + } + + if g, e := a[0].b, t1; g != e { + tt.Fatal(g, e) + } + }) + + t.Run("regexp matches", func(tt *testing.T) { + row := connection.QueryRowContext(ctx, "select 'seafood' regexp 'foo.*'") + + var r int + if err := row.Scan(&r); err != nil { + tt.Fatal(err) + } + + if g, e := r, 1; g != e { + tt.Fatal(g, e) + } + }) + + t.Run("regexp does not match", func(tt *testing.T) { + row := connection.QueryRowContext(ctx, "select 'fruit' regexp 'foo.*'") + + var r int + if err := row.Scan(&r); err != nil { + tt.Fatal(err) + } + + if g, e := r, 0; g != e { + tt.Fatal(g, e) + } + }) + + t.Run("errors on bad regexp", func(tt *testing.T) { + err := connection.QueryRowContext(ctx, "select 'seafood' regexp 'a(b'").Scan() + if err == nil { + tt.Fatal(errors.New("expected error, got none")) + } + }) + + t.Run("errors on bad first argument", func(tt *testing.T) { + err := connection.QueryRowContext(ctx, "SELECT 1 REGEXP 'a(b'").Scan() + if err == nil { + tt.Fatal(errors.New("expected error, got none")) + } + }) + + t.Run("errors on bad second argument", func(tt *testing.T) { + err := connection.QueryRowContext(ctx, "SELECT 'seafood' REGEXP 1").Scan() + if err == nil { + tt.Fatal(errors.New("expected error, got none")) + } + }) +} + func TestBlob(t *testing.T) { dir, db := tempDB(t) @@ -561,6 +821,11 @@ func TestConcurrentProcesses(t *testing.T) { t.Skip("skipping test in short mode") } + //TODO The current riscv64 board seems too slow for the hardcoded timeouts. + if runtime.GOARCH == "riscv64" { + t.Skip("skipping test") + } + dir, err := ioutil.TempDir("", "sqlite-test-") if err != nil { t.Fatal(err) diff --git a/benchmark/README.md b/benchmark/README.md index 4bf80a1..9ccca91 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -12,7 +12,7 @@ Additional command line arguments: | flag | type | default | description | | ---- | ---- | ------- | ----------------------------------------------------------------------------------------------- | -| -mem | bool | false | if set - benchmarks will use in-memory SQLite instance, otherwise: on-disk instance | +| -mem | bool | false | if true: benchmarks will use in-memory SQLite instance, otherwise: on-disk instance | | -rep | uint | 1 | run each benchmark multiple times and average the results. this may provide more stable results | diff --git a/benchmark/go.sum b/benchmark/go.sum index 7688541..0f8b9e1 100644 --- a/benchmark/go.sum +++ b/benchmark/go.sum @@ -56,6 +56,9 @@ modernc.org/cc/v3 v3.35.15/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g modernc.org/cc/v3 v3.35.16/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= modernc.org/cc/v3 v3.35.17/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= modernc.org/cc/v3 v3.35.18/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.20/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.22 h1:BzShpwCAP7TWzFppM4k2t03RhXhgYqaibROWkrWq7lE= +modernc.org/cc/v3 v3.35.22/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= modernc.org/ccgo/v3 v3.9.5/go.mod h1:umuo2EP2oDSBnD3ckjaVUXMrmeAw8C8OSICVa0iFf60= modernc.org/ccgo/v3 v3.10.0/go.mod h1:c0yBmkRFi7uW4J7fwx/JiijwOjeAeR2NoSaRVFPmjMw= modernc.org/ccgo/v3 v3.11.0/go.mod h1:dGNposbDp9TOZ/1KBxghxtUp/bzErD0/0QW4hhSaBMI= @@ -91,7 +94,6 @@ modernc.org/ccgo/v3 v3.12.73/go.mod h1:hngkB+nUUqzOf3iqsM48Gf1FZhY599qzVg1iX+BT3 modernc.org/ccgo/v3 v3.12.81/go.mod h1:p2A1duHoBBg1mFtYvnhAnQyI6vL0uw5PGYLSIgF6rYY= modernc.org/ccgo/v3 v3.12.84/go.mod h1:ApbflUfa5BKadjHynCficldU1ghjen84tuM5jRynB7w= modernc.org/ccgo/v3 v3.12.86/go.mod h1:dN7S26DLTgVSni1PVA3KxxHTcykyDurf3OgUzNqTSrU= -modernc.org/ccgo/v3 v3.12.88/go.mod h1:0MFzUHIuSIthpVZyMWiFYMwjiFnhrN5MkvBrUwON+ZM= modernc.org/ccgo/v3 v3.12.90/go.mod h1:obhSc3CdivCRpYZmrvO88TXlW0NvoSVvdh/ccRjJYko= modernc.org/ccgo/v3 v3.12.92/go.mod h1:5yDdN7ti9KWPi5bRVWPl8UNhpEAtCjuEE7ayQnzzqHA= modernc.org/ccgo/v3 v3.12.95/go.mod h1:ZcLyvtocXYi8uF+9Ebm3G8EF8HNY5hGomBqthDp4eC8= @@ -132,12 +134,12 @@ modernc.org/libc v1.11.82/go.mod h1:NF+Ek1BOl2jeC7lw3a7Jj5PWyHPwWD4aq3wVKxqV1fI= modernc.org/libc v1.11.86/go.mod h1:ePuYgoQLmvxdNT06RpGnaDKJmDNEkV7ZPKI2jnsvZoE= modernc.org/libc v1.11.87/go.mod h1:Qvd5iXTeLhI5PS0XSyqMY99282y+3euapQFxM7jYnpY= modernc.org/libc v1.11.88/go.mod h1:h3oIVe8dxmTcchcFuCcJ4nAWaoiwzKCdv82MM0oiIdQ= -modernc.org/libc v1.11.90/go.mod h1:ynK5sbjsU77AP+nn61+k+wxUGRx9rOFcIqWYYMaDZ4c= modernc.org/libc v1.11.98/go.mod h1:ynK5sbjsU77AP+nn61+k+wxUGRx9rOFcIqWYYMaDZ4c= -modernc.org/libc v1.11.99/go.mod h1:wLLYgEiY2D17NbBOEp+mIJJJBGSiy7fLL4ZrGGZ+8jI= modernc.org/libc v1.11.101/go.mod h1:wLLYgEiY2D17NbBOEp+mIJJJBGSiy7fLL4ZrGGZ+8jI= -modernc.org/libc v1.11.104 h1:gxoa5b3HPo7OzD4tKZjgnwXk/w//u1oovvjSMP3Q96Q= -modernc.org/libc v1.11.104/go.mod h1:2MH3DaF/gCU8i/UBiVE1VFRos4o523M7zipmwH8SIgQ= +modernc.org/libc v1.12.0/go.mod h1:2MH3DaF/gCU8i/UBiVE1VFRos4o523M7zipmwH8SIgQ= +modernc.org/libc v1.13.1/go.mod h1:npFeGWjmZTjFeWALQLrvklVmAxv4m80jnG3+xI8FdJk= +modernc.org/libc v1.13.2 h1:GCFjY9bmwDZ/TJC4OZOUWaNgxIxwb104C/QZrqpcVEA= +modernc.org/libc v1.13.2/go.mod h1:npFeGWjmZTjFeWALQLrvklVmAxv4m80jnG3+xI8FdJk= modernc.org/mathutil v1.1.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/mathutil v1.4.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= diff --git a/functest/func_test.go b/functest/func_test.go new file mode 100644 index 0000000..be878de --- /dev/null +++ b/functest/func_test.go @@ -0,0 +1,292 @@ +// Copyright 2022 The Sqlite Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package functest // modernc.org/sqlite/functest + +import ( + "bytes" + "crypto/md5" + "database/sql" + "database/sql/driver" + "encoding/hex" + "errors" + "fmt" + "strings" + "testing" + "time" + + sqlite3 "github.com/glebarez/go-sqlite" +) + +func init() { + sqlite3.MustRegisterDeterministicScalarFunction( + "test_int64", + 0, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + return int64(42), nil + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "test_float64", + 0, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + return float64(1e-2), nil + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "test_null", + 0, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + return nil, nil + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "test_error", + 0, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + return nil, errors.New("boom") + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "test_empty_byte_slice", + 0, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + return []byte{}, nil + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "test_nonempty_byte_slice", + 0, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + return []byte("abcdefg"), nil + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "test_empty_string", + 0, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + return "", nil + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "test_nonempty_string", + 0, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + return "abcdefg", nil + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "yesterday", + 1, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + var arg time.Time + switch argTyped := args[0].(type) { + case int64: + arg = time.Unix(argTyped, 0) + default: + fmt.Println(argTyped) + return nil, fmt.Errorf("expected argument to be int64, got: %T", argTyped) + } + return arg.Add(-24 * time.Hour), nil + }, + ) + + sqlite3.MustRegisterDeterministicScalarFunction( + "md5", + 1, + func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) { + var arg *bytes.Buffer + switch argTyped := args[0].(type) { + case string: + arg = bytes.NewBuffer([]byte(argTyped)) + case []byte: + arg = bytes.NewBuffer(argTyped) + default: + return nil, fmt.Errorf("expected argument to be a string, got: %T", argTyped) + } + w := md5.New() + if _, err := arg.WriteTo(w); err != nil { + return nil, fmt.Errorf("unable to compute md5 checksum: %s", err) + } + return hex.EncodeToString(w.Sum(nil)), nil + }, + ) +} + +func TestRegisteredFunctions(t *testing.T) { + withDB := func(test func(db *sql.DB)) { + db, err := sql.Open("sqlite", "file::memory:") + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer db.Close() + + test(db) + } + + t.Run("int64", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select test_int64()") + + var a int + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if g, e := a, 42; g != e { + tt.Fatal(g, e) + } + + }) + }) + + t.Run("float64", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select test_float64()") + + var a float64 + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if g, e := a, 1e-2; g != e { + tt.Fatal(g, e) + } + + }) + }) + + t.Run("error", func(tt *testing.T) { + withDB(func(db *sql.DB) { + err := db.QueryRow("select test_error()").Scan() + if err == nil { + tt.Fatal("expected error, got none") + } + if !strings.Contains(err.Error(), "boom") { + tt.Fatal(err) + } + }) + }) + + t.Run("empty_byte_slice", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select test_empty_byte_slice()") + + var a []byte + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if len(a) > 0 { + tt.Fatal("expected empty byte slice") + } + }) + }) + + t.Run("nonempty_byte_slice", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select test_nonempty_byte_slice()") + + var a []byte + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if g, e := a, []byte("abcdefg"); !bytes.Equal(g, e) { + tt.Fatal(string(g), string(e)) + } + }) + }) + + t.Run("empty_string", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select test_empty_string()") + + var a string + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if len(a) > 0 { + tt.Fatal("expected empty string") + } + }) + }) + + t.Run("nonempty_string", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select test_nonempty_string()") + + var a string + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if g, e := a, "abcdefg"; g != e { + tt.Fatal(g, e) + } + }) + }) + + t.Run("null", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select test_null()") + + var a interface{} + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if a != nil { + tt.Fatal("expected nil") + } + }) + }) + + t.Run("dates", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select yesterday(unixepoch('2018-11-01'))") + + var a int64 + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if g, e := time.Unix(a, 0), time.Date(2018, time.October, 31, 0, 0, 0, 0, time.UTC); !g.Equal(e) { + tt.Fatal(g, e) + } + }) + }) + + t.Run("md5", func(tt *testing.T) { + withDB(func(db *sql.DB) { + row := db.QueryRow("select md5('abcdefg')") + + var a string + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if g, e := a, "7ac66c0f148de9519b8bd264312c4d64"; g != e { + tt.Fatal(g, e) + } + }) + }) + + t.Run("md5 with blob input", func(tt *testing.T) { + withDB(func(db *sql.DB) { + if _, err := db.Exec("create table t(b blob); insert into t values (?)", []byte("abcdefg")); err != nil { + tt.Fatal(err) + } + row := db.QueryRow("select md5(b) from t") + + var a []byte + if err := row.Scan(&a); err != nil { + tt.Fatal(err) + } + if g, e := a, []byte("7ac66c0f148de9519b8bd264312c4d64"); !bytes.Equal(g, e) { + tt.Fatal(string(g), string(e)) + } + }) + }) +} diff --git a/go.mod b/go.mod index 5d429d4..8fbef71 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,9 @@ module github.com/glebarez/go-sqlite go 1.16 require ( - golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac + github.com/mattn/go-isatty v0.0.14 // indirect + golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 modernc.org/libc v1.14.12 modernc.org/mathutil v1.4.1 - modernc.org/sqlite v1.15.4 + modernc.org/sqlite v1.16.0 ) diff --git a/go.sum b/go.sum index a881f6b..7fa0629 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,9 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= @@ -24,9 +25,11 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201126233918-771906719818/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac h1:oN6lz7iLW/YC7un8pq+9bOLyXrprv2+DKfkJY+2LJJw= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4= +golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -162,8 +165,8 @@ modernc.org/memory v1.0.6/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw= modernc.org/memory v1.0.7 h1:UE3cxTRFa5tfUibAV7Jqq8P7zRY0OlJg+yWVIIaluEE= modernc.org/memory v1.0.7/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw= modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sqlite v1.15.4 h1:pr3EA3Rety3j1c/9pCyGAe5d3vjF6wQwusHdgGCjIqc= -modernc.org/sqlite v1.15.4/go.mod h1:Jwe13ItpESZ+78K5WS6+AjXsUg+JvirsjN3iIDO4C8k= +modernc.org/sqlite v1.16.0 h1:DdvOGaWN0y+X7t2L7RUD63gcwbVjYZjcBZnA68g44EI= +modernc.org/sqlite v1.16.0/go.mod h1:Jwe13ItpESZ+78K5WS6+AjXsUg+JvirsjN3iIDO4C8k= modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= modernc.org/tcl v1.11.2/go.mod h1:BRzgpajcGdS2qTxniOx9c/dcxjlbA7p12eJNmiriQYo= modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/sqlite.go b/sqlite.go index 304167b..f5b5db7 100644 --- a/sqlite.go +++ b/sqlite.go @@ -1153,14 +1153,18 @@ func (c *conn) bindText(pstmt uintptr, idx1 int, value string) (uintptr, error) // int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*)); func (c *conn) bindBlob(pstmt uintptr, idx1 int, value []byte) (uintptr, error) { + if len(value) == 0 { + if rc := sqlite3.Xsqlite3_bind_zeroblob(c.tls, pstmt, int32(idx1), 0); rc != sqlite3.SQLITE_OK { + return 0, c.errstr(rc) + } + return 0, nil + } + p, err := c.malloc(len(value)) if err != nil { return 0, err } - - if len(value) != 0 { - copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value) - } + copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value) if rc := sqlite3.Xsqlite3_bind_blob(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK { c.free(p) return 0, c.errstr(rc) @@ -1370,6 +1374,7 @@ func (c *conn) Close() error { c.db = 0 } + if c.tls != nil { c.tls.Close() c.tls = nil @@ -1386,6 +1391,32 @@ func (c *conn) closeV2(db uintptr) error { return nil } +type userDefinedFunction struct { + zFuncName uintptr + nArg int32 + eTextRep int32 + xFunc func(*libc.TLS, uintptr, int32, uintptr) + + freeOnce sync.Once +} + +func (c *conn) createFunctionInternal(fun *userDefinedFunction) error { + if rc := sqlite3.Xsqlite3_create_function( + c.tls, + c.db, + fun.zFuncName, + fun.nArg, + fun.eTextRep, + 0, + *(*uintptr)(unsafe.Pointer(&fun.xFunc)), + 0, + 0, + ); rc != sqlite3.SQLITE_OK { + return c.errstr(rc) + } + return nil +} + // Execer is an optional interface that may be implemented by a Conn. // // If a Conn does not implement Execer, the sql package's DB.Exec will first @@ -1451,9 +1482,14 @@ func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue } // Driver implements database/sql/driver.Driver. -type Driver struct{} +type Driver struct { + // user defined functions that are added to every new connection on Open + udfs map[string]*userDefinedFunction +} -func newDriver() *Driver { return &Driver{} } +var d = &Driver{udfs: make(map[string]*userDefinedFunction)} + +func newDriver() *Driver { return d } // Open returns a new connection to the database. The name is a string in a // driver-specific format. @@ -1488,5 +1524,181 @@ func (d *Driver) Open(name string) (driver.Conn, error) { if LogSqlStatements { log.Println("new connection") } - return newConn(name) + + c, err := newConn(name) + if err != nil { + return nil, err + } + + for _, udf := range d.udfs { + if err = c.createFunctionInternal(udf); err != nil { + c.Close() + return nil, err + } + } + return c, nil +} + +// FunctionContext represents the context user defined functions execute in. +// Fields and/or methods of this type may get addedd in the future. +type FunctionContext struct{} + +const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{}) + +// RegisterScalarFunction registers a scalar function named zFuncName with nArg +// arguments. Passing -1 for nArg indicates the function is variadic. +// +// The new function will be available to all new connections opened after +// executing RegisterScalarFunction. +func RegisterScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8, xFunc) +} + +// MustRegisterScalarFunction is like RegisterScalarFunction but panics on +// error. +func MustRegisterScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) { + if err := RegisterScalarFunction(zFuncName, nArg, xFunc); err != nil { + panic(err) + } +} + +// MustRegisterDeterministicScalarFunction is like +// RegisterDeterministicScalarFunction but panics on error. +func MustRegisterDeterministicScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) { + if err := RegisterDeterministicScalarFunction(zFuncName, nArg, xFunc); err != nil { + panic(err) + } +} + +// RegisterDeterministicScalarFunction registers a deterministic scalar +// function named zFuncName with nArg arguments. Passing -1 for nArg indicates +// the function is variadic. A deterministic function means that the function +// always gives the same output when the input parameters are the same. +// +// The new function will be available to all new connections opened after +// executing RegisterDeterministicScalarFunction. +func RegisterDeterministicScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8|sqlite3.SQLITE_DETERMINISTIC, xFunc) +} + +func registerScalarFunction( + zFuncName string, + nArg int32, + eTextRep int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + + if _, ok := d.udfs[zFuncName]; ok { + return fmt.Errorf("a function named %q is already registered", zFuncName) + } + + // dont free, functions registered on the driver live as long as the program + name, err := libc.CString(zFuncName) + if err != nil { + return err + } + + udf := &userDefinedFunction{ + zFuncName: name, + nArg: nArg, + eTextRep: eTextRep, + xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { + setErrorResult := func(res error) { + errmsg, cerr := libc.CString(res.Error()) + if cerr != nil { + panic(cerr) + } + defer libc.Xfree(tls, errmsg) + sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1) + sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR) + } + + args := make([]driver.Value, argc) + for i := int32(0); i < argc; i++ { + valPtr := *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize)) + + switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType { + case sqlite3.SQLITE_TEXT: + args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr)) + case sqlite3.SQLITE_INTEGER: + args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr) + case sqlite3.SQLITE_FLOAT: + args[i] = sqlite3.Xsqlite3_value_double(tls, valPtr) + case sqlite3.SQLITE_NULL: + args[i] = nil + case sqlite3.SQLITE_BLOB: + size := sqlite3.Xsqlite3_value_bytes(tls, valPtr) + blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr) + v := make([]byte, size) + copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size]) + args[i] = v + default: + panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType)) + } + } + + res, err := xFunc(&FunctionContext{}, args) + if err != nil { + setErrorResult(err) + return + } + + switch resTyped := res.(type) { + case nil: + sqlite3.Xsqlite3_result_null(tls, ctx) + case int64: + sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped) + case float64: + sqlite3.Xsqlite3_result_double(tls, ctx, resTyped) + case bool: + sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(resTyped)) + case time.Time: + sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped.Unix()) + case string: + size := int32(len(resTyped)) + cstr, err := libc.CString(resTyped) + if err != nil { + panic(err) + } + defer libc.Xfree(tls, cstr) + sqlite3.Xsqlite3_result_text(tls, ctx, cstr, size, sqlite3.SQLITE_TRANSIENT) + case []byte: + size := int32(len(resTyped)) + if size == 0 { + sqlite3.Xsqlite3_result_zeroblob(tls, ctx, 0) + return + } + p := libc.Xmalloc(tls, types.Size_t(size)) + if p == 0 { + panic(fmt.Sprintf("unable to allocate space for blob: %d", size)) + } + defer libc.Xfree(tls, p) + copy((*libc.RawMem)(unsafe.Pointer(p))[:size:size], resTyped) + + sqlite3.Xsqlite3_result_blob(tls, ctx, p, size, sqlite3.SQLITE_TRANSIENT) + default: + setErrorResult(fmt.Errorf("function did not return a valid driver.Value: %T", resTyped)) + return + } + }, + } + d.udfs[zFuncName] = udf + + return nil }