merge upstream

This commit is contained in:
glebarez
2022-04-05 18:04:19 +03:00
parent b791ecc831
commit 2097e63911
9 changed files with 797 additions and 20 deletions

View File

@@ -13,6 +13,7 @@ Davsk Ltd Co <skinner.david@gmail.com>
Jaap Aarts <jaap.aarts1@gmail.com>
Jan Mercl <0xjnml@gmail.com>
Logan Snow <logansnow@protonmail.com>
Michael Hoffmann <mhoffm@posteo.de>
Ross Light <ross@zombiezen.com>
Steffen Butzer <steffen(dot)butzer@outlook.com>
Saed SayedAhmed <saadmtsa@gmail.com>

View File

@@ -14,7 +14,8 @@ Jaap Aarts <jaap.aarts1@gmail.com>
Jan Mercl <0xjnml@gmail.com>
Logan Snow <logansnow@protonmail.com>
Matthew Gabeler-Lee <fastcat@gmail.com>
Michael Hoffmann <mhoffm@posteo.de>
Ross Light <ross@zombiezen.com>
Steffen Butzer <steffen(dot)butzer@outlook.com>
Yaacov Akiba Slama <ya@slamail.org>
Saed SayedAhmed <saadmtsa@gmail.com>
Saed SayedAhmed <saadmtsa@gmail.com>

View File

@@ -9,6 +9,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"flag"
"fmt"
"io/ioutil"
@@ -19,6 +20,7 @@ import (
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"runtime/debug"
"strconv"
@@ -26,6 +28,7 @@ import (
"sync"
"testing"
"time"
"unsafe"
"modernc.org/libc"
"modernc.org/mathutil"
@@ -146,6 +149,26 @@ func tempDB(t testing.TB) (string, *sql.DB) {
return dir, db
}
// https://gitlab.com/cznic/sqlite/issues/98
func TestIssue98(t *testing.T) {
dir, db := tempDB(t)
defer func() {
db.Close()
os.RemoveAll(dir)
}()
if _, err := db.Exec("create table t(b mediumblob not null)"); err != nil {
t.Fatal(err)
}
if _, err := db.Exec("insert into t values (?)", []byte{}); err != nil {
t.Fatal(err)
}
if _, err := db.Exec("insert into t values (?)", nil); err == nil {
t.Fatal(errors.New("expected statement to fail"))
}
}
func TestScalar(t *testing.T) {
dir, db := tempDB(t)
@@ -214,6 +237,243 @@ func TestScalar(t *testing.T) {
}
}
func TestRedefineUserDefinedFunction(t *testing.T) {
dir, db := tempDB(t)
ctx := context.Background()
defer func() {
db.Close()
os.RemoveAll(dir)
}()
connection, err := db.Conn(context.Background())
if err != nil {
t.Fatal(err)
}
var r int
funName := "test"
if err = connection.Raw(func(driverConn interface{}) error {
c := driverConn.(*conn)
name, err := libc.CString(funName)
if err != nil {
return err
}
return c.createFunctionInternal(&userDefinedFunction{
zFuncName: name,
nArg: 0,
eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC,
xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
sqlite3.Xsqlite3_result_int(tls, ctx, 1)
},
})
}); err != nil {
t.Fatal(err)
}
row := connection.QueryRowContext(ctx, "select test()")
if err := row.Scan(&r); err != nil {
t.Fatal(err)
}
if g, e := r, 1; g != e {
t.Fatal(g, e)
}
if err = connection.Raw(func(driverConn interface{}) error {
c := driverConn.(*conn)
name, err := libc.CString(funName)
if err != nil {
return err
}
return c.createFunctionInternal(&userDefinedFunction{
zFuncName: name,
nArg: 0,
eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC,
xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
sqlite3.Xsqlite3_result_int(tls, ctx, 2)
},
})
}); err != nil {
t.Fatal(err)
}
row = connection.QueryRowContext(ctx, "select test()")
if err := row.Scan(&r); err != nil {
t.Fatal(err)
}
if g, e := r, 2; g != e {
t.Fatal(g, e)
}
}
func TestRegexpUserDefinedFunction(t *testing.T) {
dir, db := tempDB(t)
ctx := context.Background()
defer func() {
db.Close()
os.RemoveAll(dir)
}()
connection, err := db.Conn(context.Background())
if err != nil {
t.Fatal(err)
}
if err = connection.Raw(func(driverConn interface{}) error {
c := driverConn.(*conn)
name, err := libc.CString("regexp")
if err != nil {
return err
}
return c.createFunctionInternal(&userDefinedFunction{
zFuncName: name,
nArg: 2,
eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC,
xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{})
argvv := make([]uintptr, argc)
for i := int32(0); i < argc; i++ {
argvv[i] = *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize))
}
setErrorResult := func(res error) {
errmsg, cerr := libc.CString(res.Error())
if cerr != nil {
panic(cerr)
}
defer libc.Xfree(tls, errmsg)
sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1)
sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR)
}
var s1 string
switch sqlite3.Xsqlite3_value_type(tls, argvv[0]) {
case sqlite3.SQLITE_TEXT:
s1 = libc.GoString(sqlite3.Xsqlite3_value_text(tls, argvv[0]))
default:
setErrorResult(errors.New("expected argv[0] to be text"))
return
}
var s2 string
switch sqlite3.Xsqlite3_value_type(tls, argvv[1]) {
case sqlite3.SQLITE_TEXT:
s2 = libc.GoString(sqlite3.Xsqlite3_value_text(tls, argvv[1]))
default:
setErrorResult(errors.New("expected argv[1] to be text"))
return
}
matched, err := regexp.MatchString(s1, s2)
if err != nil {
setErrorResult(fmt.Errorf("bad regular expression: %q", err))
return
}
sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(matched))
},
})
}); err != nil {
t.Fatal(err)
}
t.Run("regexp filter", func(tt *testing.T) {
t1 := "seafood"
t2 := "fruit"
connection.ExecContext(ctx, `
create table t(b text);
insert into t values(?), (?);
`, t1, t2)
rows, err := connection.QueryContext(ctx, "select * from t where b regexp 'foo.*'")
if err != nil {
tt.Fatal(err)
}
type rec struct {
b string
}
var a []rec
for rows.Next() {
var r rec
if err := rows.Scan(&r.b); err != nil {
tt.Fatal(err)
}
a = append(a, r)
}
if err := rows.Err(); err != nil {
tt.Fatal(err)
}
if g, e := len(a), 1; g != e {
tt.Fatal(g, e)
}
if g, e := a[0].b, t1; g != e {
tt.Fatal(g, e)
}
})
t.Run("regexp matches", func(tt *testing.T) {
row := connection.QueryRowContext(ctx, "select 'seafood' regexp 'foo.*'")
var r int
if err := row.Scan(&r); err != nil {
tt.Fatal(err)
}
if g, e := r, 1; g != e {
tt.Fatal(g, e)
}
})
t.Run("regexp does not match", func(tt *testing.T) {
row := connection.QueryRowContext(ctx, "select 'fruit' regexp 'foo.*'")
var r int
if err := row.Scan(&r); err != nil {
tt.Fatal(err)
}
if g, e := r, 0; g != e {
tt.Fatal(g, e)
}
})
t.Run("errors on bad regexp", func(tt *testing.T) {
err := connection.QueryRowContext(ctx, "select 'seafood' regexp 'a(b'").Scan()
if err == nil {
tt.Fatal(errors.New("expected error, got none"))
}
})
t.Run("errors on bad first argument", func(tt *testing.T) {
err := connection.QueryRowContext(ctx, "SELECT 1 REGEXP 'a(b'").Scan()
if err == nil {
tt.Fatal(errors.New("expected error, got none"))
}
})
t.Run("errors on bad second argument", func(tt *testing.T) {
err := connection.QueryRowContext(ctx, "SELECT 'seafood' REGEXP 1").Scan()
if err == nil {
tt.Fatal(errors.New("expected error, got none"))
}
})
}
func TestBlob(t *testing.T) {
dir, db := tempDB(t)
@@ -561,6 +821,11 @@ func TestConcurrentProcesses(t *testing.T) {
t.Skip("skipping test in short mode")
}
//TODO The current riscv64 board seems too slow for the hardcoded timeouts.
if runtime.GOARCH == "riscv64" {
t.Skip("skipping test")
}
dir, err := ioutil.TempDir("", "sqlite-test-")
if err != nil {
t.Fatal(err)

View File

@@ -12,7 +12,7 @@ Additional command line arguments:
| flag | type | default | description |
| ---- | ---- | ------- | ----------------------------------------------------------------------------------------------- |
| -mem | bool | false | if set - benchmarks will use in-memory SQLite instance, otherwise: on-disk instance |
| -mem | bool | false | if true: benchmarks will use in-memory SQLite instance, otherwise: on-disk instance |
| -rep | uint | 1 | run each benchmark multiple times and average the results. this may provide more stable results |

View File

@@ -56,6 +56,9 @@ modernc.org/cc/v3 v3.35.15/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g
modernc.org/cc/v3 v3.35.16/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g=
modernc.org/cc/v3 v3.35.17/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g=
modernc.org/cc/v3 v3.35.18/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g=
modernc.org/cc/v3 v3.35.20/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g=
modernc.org/cc/v3 v3.35.22 h1:BzShpwCAP7TWzFppM4k2t03RhXhgYqaibROWkrWq7lE=
modernc.org/cc/v3 v3.35.22/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g=
modernc.org/ccgo/v3 v3.9.5/go.mod h1:umuo2EP2oDSBnD3ckjaVUXMrmeAw8C8OSICVa0iFf60=
modernc.org/ccgo/v3 v3.10.0/go.mod h1:c0yBmkRFi7uW4J7fwx/JiijwOjeAeR2NoSaRVFPmjMw=
modernc.org/ccgo/v3 v3.11.0/go.mod h1:dGNposbDp9TOZ/1KBxghxtUp/bzErD0/0QW4hhSaBMI=
@@ -91,7 +94,6 @@ modernc.org/ccgo/v3 v3.12.73/go.mod h1:hngkB+nUUqzOf3iqsM48Gf1FZhY599qzVg1iX+BT3
modernc.org/ccgo/v3 v3.12.81/go.mod h1:p2A1duHoBBg1mFtYvnhAnQyI6vL0uw5PGYLSIgF6rYY=
modernc.org/ccgo/v3 v3.12.84/go.mod h1:ApbflUfa5BKadjHynCficldU1ghjen84tuM5jRynB7w=
modernc.org/ccgo/v3 v3.12.86/go.mod h1:dN7S26DLTgVSni1PVA3KxxHTcykyDurf3OgUzNqTSrU=
modernc.org/ccgo/v3 v3.12.88/go.mod h1:0MFzUHIuSIthpVZyMWiFYMwjiFnhrN5MkvBrUwON+ZM=
modernc.org/ccgo/v3 v3.12.90/go.mod h1:obhSc3CdivCRpYZmrvO88TXlW0NvoSVvdh/ccRjJYko=
modernc.org/ccgo/v3 v3.12.92/go.mod h1:5yDdN7ti9KWPi5bRVWPl8UNhpEAtCjuEE7ayQnzzqHA=
modernc.org/ccgo/v3 v3.12.95/go.mod h1:ZcLyvtocXYi8uF+9Ebm3G8EF8HNY5hGomBqthDp4eC8=
@@ -132,12 +134,12 @@ modernc.org/libc v1.11.82/go.mod h1:NF+Ek1BOl2jeC7lw3a7Jj5PWyHPwWD4aq3wVKxqV1fI=
modernc.org/libc v1.11.86/go.mod h1:ePuYgoQLmvxdNT06RpGnaDKJmDNEkV7ZPKI2jnsvZoE=
modernc.org/libc v1.11.87/go.mod h1:Qvd5iXTeLhI5PS0XSyqMY99282y+3euapQFxM7jYnpY=
modernc.org/libc v1.11.88/go.mod h1:h3oIVe8dxmTcchcFuCcJ4nAWaoiwzKCdv82MM0oiIdQ=
modernc.org/libc v1.11.90/go.mod h1:ynK5sbjsU77AP+nn61+k+wxUGRx9rOFcIqWYYMaDZ4c=
modernc.org/libc v1.11.98/go.mod h1:ynK5sbjsU77AP+nn61+k+wxUGRx9rOFcIqWYYMaDZ4c=
modernc.org/libc v1.11.99/go.mod h1:wLLYgEiY2D17NbBOEp+mIJJJBGSiy7fLL4ZrGGZ+8jI=
modernc.org/libc v1.11.101/go.mod h1:wLLYgEiY2D17NbBOEp+mIJJJBGSiy7fLL4ZrGGZ+8jI=
modernc.org/libc v1.11.104 h1:gxoa5b3HPo7OzD4tKZjgnwXk/w//u1oovvjSMP3Q96Q=
modernc.org/libc v1.11.104/go.mod h1:2MH3DaF/gCU8i/UBiVE1VFRos4o523M7zipmwH8SIgQ=
modernc.org/libc v1.12.0/go.mod h1:2MH3DaF/gCU8i/UBiVE1VFRos4o523M7zipmwH8SIgQ=
modernc.org/libc v1.13.1/go.mod h1:npFeGWjmZTjFeWALQLrvklVmAxv4m80jnG3+xI8FdJk=
modernc.org/libc v1.13.2 h1:GCFjY9bmwDZ/TJC4OZOUWaNgxIxwb104C/QZrqpcVEA=
modernc.org/libc v1.13.2/go.mod h1:npFeGWjmZTjFeWALQLrvklVmAxv4m80jnG3+xI8FdJk=
modernc.org/mathutil v1.1.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/mathutil v1.4.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=

292
functest/func_test.go Normal file
View File

@@ -0,0 +1,292 @@
// Copyright 2022 The Sqlite Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package functest // modernc.org/sqlite/functest
import (
"bytes"
"crypto/md5"
"database/sql"
"database/sql/driver"
"encoding/hex"
"errors"
"fmt"
"strings"
"testing"
"time"
sqlite3 "github.com/glebarez/go-sqlite"
)
func init() {
sqlite3.MustRegisterDeterministicScalarFunction(
"test_int64",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return int64(42), nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_float64",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return float64(1e-2), nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_null",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return nil, nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_error",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return nil, errors.New("boom")
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_empty_byte_slice",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return []byte{}, nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_nonempty_byte_slice",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return []byte("abcdefg"), nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_empty_string",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return "", nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_nonempty_string",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return "abcdefg", nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"yesterday",
1,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
var arg time.Time
switch argTyped := args[0].(type) {
case int64:
arg = time.Unix(argTyped, 0)
default:
fmt.Println(argTyped)
return nil, fmt.Errorf("expected argument to be int64, got: %T", argTyped)
}
return arg.Add(-24 * time.Hour), nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"md5",
1,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
var arg *bytes.Buffer
switch argTyped := args[0].(type) {
case string:
arg = bytes.NewBuffer([]byte(argTyped))
case []byte:
arg = bytes.NewBuffer(argTyped)
default:
return nil, fmt.Errorf("expected argument to be a string, got: %T", argTyped)
}
w := md5.New()
if _, err := arg.WriteTo(w); err != nil {
return nil, fmt.Errorf("unable to compute md5 checksum: %s", err)
}
return hex.EncodeToString(w.Sum(nil)), nil
},
)
}
func TestRegisteredFunctions(t *testing.T) {
withDB := func(test func(db *sql.DB)) {
db, err := sql.Open("sqlite", "file::memory:")
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
defer db.Close()
test(db)
}
t.Run("int64", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_int64()")
var a int
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, 42; g != e {
tt.Fatal(g, e)
}
})
})
t.Run("float64", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_float64()")
var a float64
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, 1e-2; g != e {
tt.Fatal(g, e)
}
})
})
t.Run("error", func(tt *testing.T) {
withDB(func(db *sql.DB) {
err := db.QueryRow("select test_error()").Scan()
if err == nil {
tt.Fatal("expected error, got none")
}
if !strings.Contains(err.Error(), "boom") {
tt.Fatal(err)
}
})
})
t.Run("empty_byte_slice", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_empty_byte_slice()")
var a []byte
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if len(a) > 0 {
tt.Fatal("expected empty byte slice")
}
})
})
t.Run("nonempty_byte_slice", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_nonempty_byte_slice()")
var a []byte
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, []byte("abcdefg"); !bytes.Equal(g, e) {
tt.Fatal(string(g), string(e))
}
})
})
t.Run("empty_string", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_empty_string()")
var a string
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if len(a) > 0 {
tt.Fatal("expected empty string")
}
})
})
t.Run("nonempty_string", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_nonempty_string()")
var a string
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, "abcdefg"; g != e {
tt.Fatal(g, e)
}
})
})
t.Run("null", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_null()")
var a interface{}
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if a != nil {
tt.Fatal("expected nil")
}
})
})
t.Run("dates", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select yesterday(unixepoch('2018-11-01'))")
var a int64
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := time.Unix(a, 0), time.Date(2018, time.October, 31, 0, 0, 0, 0, time.UTC); !g.Equal(e) {
tt.Fatal(g, e)
}
})
})
t.Run("md5", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select md5('abcdefg')")
var a string
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, "7ac66c0f148de9519b8bd264312c4d64"; g != e {
tt.Fatal(g, e)
}
})
})
t.Run("md5 with blob input", func(tt *testing.T) {
withDB(func(db *sql.DB) {
if _, err := db.Exec("create table t(b blob); insert into t values (?)", []byte("abcdefg")); err != nil {
tt.Fatal(err)
}
row := db.QueryRow("select md5(b) from t")
var a []byte
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, []byte("7ac66c0f148de9519b8bd264312c4d64"); !bytes.Equal(g, e) {
tt.Fatal(string(g), string(e))
}
})
})
}

5
go.mod
View File

@@ -3,8 +3,9 @@ module github.com/glebarez/go-sqlite
go 1.16
require (
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac
github.com/mattn/go-isatty v0.0.14 // indirect
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64
modernc.org/libc v1.14.12
modernc.org/mathutil v1.4.1
modernc.org/sqlite v1.15.4
modernc.org/sqlite v1.16.0
)

11
go.sum
View File

@@ -3,8 +3,9 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk=
@@ -24,9 +25,11 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201126233918-771906719818/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac h1:oN6lz7iLW/YC7un8pq+9bOLyXrprv2+DKfkJY+2LJJw=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -162,8 +165,8 @@ modernc.org/memory v1.0.6/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw=
modernc.org/memory v1.0.7 h1:UE3cxTRFa5tfUibAV7Jqq8P7zRY0OlJg+yWVIIaluEE=
modernc.org/memory v1.0.7/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw=
modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
modernc.org/sqlite v1.15.4 h1:pr3EA3Rety3j1c/9pCyGAe5d3vjF6wQwusHdgGCjIqc=
modernc.org/sqlite v1.15.4/go.mod h1:Jwe13ItpESZ+78K5WS6+AjXsUg+JvirsjN3iIDO4C8k=
modernc.org/sqlite v1.16.0 h1:DdvOGaWN0y+X7t2L7RUD63gcwbVjYZjcBZnA68g44EI=
modernc.org/sqlite v1.16.0/go.mod h1:Jwe13ItpESZ+78K5WS6+AjXsUg+JvirsjN3iIDO4C8k=
modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw=
modernc.org/tcl v1.11.2/go.mod h1:BRzgpajcGdS2qTxniOx9c/dcxjlbA7p12eJNmiriQYo=
modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

226
sqlite.go
View File

@@ -1153,14 +1153,18 @@ func (c *conn) bindText(pstmt uintptr, idx1 int, value string) (uintptr, error)
// int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*));
func (c *conn) bindBlob(pstmt uintptr, idx1 int, value []byte) (uintptr, error) {
if len(value) == 0 {
if rc := sqlite3.Xsqlite3_bind_zeroblob(c.tls, pstmt, int32(idx1), 0); rc != sqlite3.SQLITE_OK {
return 0, c.errstr(rc)
}
return 0, nil
}
p, err := c.malloc(len(value))
if err != nil {
return 0, err
}
if len(value) != 0 {
copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value)
}
copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value)
if rc := sqlite3.Xsqlite3_bind_blob(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK {
c.free(p)
return 0, c.errstr(rc)
@@ -1370,6 +1374,7 @@ func (c *conn) Close() error {
c.db = 0
}
if c.tls != nil {
c.tls.Close()
c.tls = nil
@@ -1386,6 +1391,32 @@ func (c *conn) closeV2(db uintptr) error {
return nil
}
type userDefinedFunction struct {
zFuncName uintptr
nArg int32
eTextRep int32
xFunc func(*libc.TLS, uintptr, int32, uintptr)
freeOnce sync.Once
}
func (c *conn) createFunctionInternal(fun *userDefinedFunction) error {
if rc := sqlite3.Xsqlite3_create_function(
c.tls,
c.db,
fun.zFuncName,
fun.nArg,
fun.eTextRep,
0,
*(*uintptr)(unsafe.Pointer(&fun.xFunc)),
0,
0,
); rc != sqlite3.SQLITE_OK {
return c.errstr(rc)
}
return nil
}
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the sql package's DB.Exec will first
@@ -1451,9 +1482,14 @@ func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue
}
// Driver implements database/sql/driver.Driver.
type Driver struct{}
type Driver struct {
// user defined functions that are added to every new connection on Open
udfs map[string]*userDefinedFunction
}
func newDriver() *Driver { return &Driver{} }
var d = &Driver{udfs: make(map[string]*userDefinedFunction)}
func newDriver() *Driver { return d }
// Open returns a new connection to the database. The name is a string in a
// driver-specific format.
@@ -1488,5 +1524,181 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
if LogSqlStatements {
log.Println("new connection")
}
return newConn(name)
c, err := newConn(name)
if err != nil {
return nil, err
}
for _, udf := range d.udfs {
if err = c.createFunctionInternal(udf); err != nil {
c.Close()
return nil, err
}
}
return c, nil
}
// FunctionContext represents the context user defined functions execute in.
// Fields and/or methods of this type may get addedd in the future.
type FunctionContext struct{}
const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{})
// RegisterScalarFunction registers a scalar function named zFuncName with nArg
// arguments. Passing -1 for nArg indicates the function is variadic.
//
// The new function will be available to all new connections opened after
// executing RegisterScalarFunction.
func RegisterScalarFunction(
zFuncName string,
nArg int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) error {
return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8, xFunc)
}
// MustRegisterScalarFunction is like RegisterScalarFunction but panics on
// error.
func MustRegisterScalarFunction(
zFuncName string,
nArg int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) {
if err := RegisterScalarFunction(zFuncName, nArg, xFunc); err != nil {
panic(err)
}
}
// MustRegisterDeterministicScalarFunction is like
// RegisterDeterministicScalarFunction but panics on error.
func MustRegisterDeterministicScalarFunction(
zFuncName string,
nArg int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) {
if err := RegisterDeterministicScalarFunction(zFuncName, nArg, xFunc); err != nil {
panic(err)
}
}
// RegisterDeterministicScalarFunction registers a deterministic scalar
// function named zFuncName with nArg arguments. Passing -1 for nArg indicates
// the function is variadic. A deterministic function means that the function
// always gives the same output when the input parameters are the same.
//
// The new function will be available to all new connections opened after
// executing RegisterDeterministicScalarFunction.
func RegisterDeterministicScalarFunction(
zFuncName string,
nArg int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) error {
return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8|sqlite3.SQLITE_DETERMINISTIC, xFunc)
}
func registerScalarFunction(
zFuncName string,
nArg int32,
eTextRep int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) error {
if _, ok := d.udfs[zFuncName]; ok {
return fmt.Errorf("a function named %q is already registered", zFuncName)
}
// dont free, functions registered on the driver live as long as the program
name, err := libc.CString(zFuncName)
if err != nil {
return err
}
udf := &userDefinedFunction{
zFuncName: name,
nArg: nArg,
eTextRep: eTextRep,
xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
setErrorResult := func(res error) {
errmsg, cerr := libc.CString(res.Error())
if cerr != nil {
panic(cerr)
}
defer libc.Xfree(tls, errmsg)
sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1)
sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR)
}
args := make([]driver.Value, argc)
for i := int32(0); i < argc; i++ {
valPtr := *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize))
switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType {
case sqlite3.SQLITE_TEXT:
args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr))
case sqlite3.SQLITE_INTEGER:
args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr)
case sqlite3.SQLITE_FLOAT:
args[i] = sqlite3.Xsqlite3_value_double(tls, valPtr)
case sqlite3.SQLITE_NULL:
args[i] = nil
case sqlite3.SQLITE_BLOB:
size := sqlite3.Xsqlite3_value_bytes(tls, valPtr)
blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr)
v := make([]byte, size)
copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size])
args[i] = v
default:
panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType))
}
}
res, err := xFunc(&FunctionContext{}, args)
if err != nil {
setErrorResult(err)
return
}
switch resTyped := res.(type) {
case nil:
sqlite3.Xsqlite3_result_null(tls, ctx)
case int64:
sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped)
case float64:
sqlite3.Xsqlite3_result_double(tls, ctx, resTyped)
case bool:
sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(resTyped))
case time.Time:
sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped.Unix())
case string:
size := int32(len(resTyped))
cstr, err := libc.CString(resTyped)
if err != nil {
panic(err)
}
defer libc.Xfree(tls, cstr)
sqlite3.Xsqlite3_result_text(tls, ctx, cstr, size, sqlite3.SQLITE_TRANSIENT)
case []byte:
size := int32(len(resTyped))
if size == 0 {
sqlite3.Xsqlite3_result_zeroblob(tls, ctx, 0)
return
}
p := libc.Xmalloc(tls, types.Size_t(size))
if p == 0 {
panic(fmt.Sprintf("unable to allocate space for blob: %d", size))
}
defer libc.Xfree(tls, p)
copy((*libc.RawMem)(unsafe.Pointer(p))[:size:size], resTyped)
sqlite3.Xsqlite3_result_blob(tls, ctx, p, size, sqlite3.SQLITE_TRANSIENT)
default:
setErrorResult(fmt.Errorf("function did not return a valid driver.Value: %T", resTyped))
return
}
},
}
d.udfs[zFuncName] = udf
return nil
}