mirror of
https://github.com/glebarez/go-sqlite.git
synced 2025-10-04 07:26:28 +08:00
merge upstream
This commit is contained in:
265
all_test.go
265
all_test.go
@@ -9,6 +9,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
@@ -26,6 +28,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"modernc.org/libc"
|
||||
"modernc.org/mathutil"
|
||||
@@ -146,6 +149,26 @@ func tempDB(t testing.TB) (string, *sql.DB) {
|
||||
return dir, db
|
||||
}
|
||||
|
||||
// https://gitlab.com/cznic/sqlite/issues/98
|
||||
func TestIssue98(t *testing.T) {
|
||||
dir, db := tempDB(t)
|
||||
|
||||
defer func() {
|
||||
db.Close()
|
||||
os.RemoveAll(dir)
|
||||
}()
|
||||
|
||||
if _, err := db.Exec("create table t(b mediumblob not null)"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := db.Exec("insert into t values (?)", []byte{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := db.Exec("insert into t values (?)", nil); err == nil {
|
||||
t.Fatal(errors.New("expected statement to fail"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestScalar(t *testing.T) {
|
||||
dir, db := tempDB(t)
|
||||
|
||||
@@ -214,6 +237,243 @@ func TestScalar(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedefineUserDefinedFunction(t *testing.T) {
|
||||
dir, db := tempDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
defer func() {
|
||||
db.Close()
|
||||
os.RemoveAll(dir)
|
||||
}()
|
||||
|
||||
connection, err := db.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var r int
|
||||
funName := "test"
|
||||
|
||||
if err = connection.Raw(func(driverConn interface{}) error {
|
||||
c := driverConn.(*conn)
|
||||
|
||||
name, err := libc.CString(funName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.createFunctionInternal(&userDefinedFunction{
|
||||
zFuncName: name,
|
||||
nArg: 0,
|
||||
eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC,
|
||||
xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
|
||||
sqlite3.Xsqlite3_result_int(tls, ctx, 1)
|
||||
},
|
||||
})
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
row := connection.QueryRowContext(ctx, "select test()")
|
||||
|
||||
if err := row.Scan(&r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := r, 1; g != e {
|
||||
t.Fatal(g, e)
|
||||
}
|
||||
|
||||
if err = connection.Raw(func(driverConn interface{}) error {
|
||||
c := driverConn.(*conn)
|
||||
|
||||
name, err := libc.CString(funName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.createFunctionInternal(&userDefinedFunction{
|
||||
zFuncName: name,
|
||||
nArg: 0,
|
||||
eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC,
|
||||
xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
|
||||
sqlite3.Xsqlite3_result_int(tls, ctx, 2)
|
||||
},
|
||||
})
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
row = connection.QueryRowContext(ctx, "select test()")
|
||||
|
||||
if err := row.Scan(&r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := r, 2; g != e {
|
||||
t.Fatal(g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexpUserDefinedFunction(t *testing.T) {
|
||||
dir, db := tempDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
defer func() {
|
||||
db.Close()
|
||||
os.RemoveAll(dir)
|
||||
}()
|
||||
|
||||
connection, err := db.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = connection.Raw(func(driverConn interface{}) error {
|
||||
c := driverConn.(*conn)
|
||||
|
||||
name, err := libc.CString("regexp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.createFunctionInternal(&userDefinedFunction{
|
||||
zFuncName: name,
|
||||
nArg: 2,
|
||||
eTextRep: sqlite3.SQLITE_UTF8 | sqlite3.SQLITE_DETERMINISTIC,
|
||||
xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
|
||||
const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{})
|
||||
|
||||
argvv := make([]uintptr, argc)
|
||||
for i := int32(0); i < argc; i++ {
|
||||
argvv[i] = *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize))
|
||||
}
|
||||
|
||||
setErrorResult := func(res error) {
|
||||
errmsg, cerr := libc.CString(res.Error())
|
||||
if cerr != nil {
|
||||
panic(cerr)
|
||||
}
|
||||
defer libc.Xfree(tls, errmsg)
|
||||
sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1)
|
||||
sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR)
|
||||
}
|
||||
|
||||
var s1 string
|
||||
switch sqlite3.Xsqlite3_value_type(tls, argvv[0]) {
|
||||
case sqlite3.SQLITE_TEXT:
|
||||
s1 = libc.GoString(sqlite3.Xsqlite3_value_text(tls, argvv[0]))
|
||||
default:
|
||||
setErrorResult(errors.New("expected argv[0] to be text"))
|
||||
return
|
||||
}
|
||||
|
||||
var s2 string
|
||||
switch sqlite3.Xsqlite3_value_type(tls, argvv[1]) {
|
||||
case sqlite3.SQLITE_TEXT:
|
||||
s2 = libc.GoString(sqlite3.Xsqlite3_value_text(tls, argvv[1]))
|
||||
default:
|
||||
setErrorResult(errors.New("expected argv[1] to be text"))
|
||||
return
|
||||
}
|
||||
|
||||
matched, err := regexp.MatchString(s1, s2)
|
||||
if err != nil {
|
||||
setErrorResult(fmt.Errorf("bad regular expression: %q", err))
|
||||
return
|
||||
}
|
||||
sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(matched))
|
||||
},
|
||||
})
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("regexp filter", func(tt *testing.T) {
|
||||
t1 := "seafood"
|
||||
t2 := "fruit"
|
||||
|
||||
connection.ExecContext(ctx, `
|
||||
create table t(b text);
|
||||
insert into t values(?), (?);
|
||||
`, t1, t2)
|
||||
|
||||
rows, err := connection.QueryContext(ctx, "select * from t where b regexp 'foo.*'")
|
||||
if err != nil {
|
||||
tt.Fatal(err)
|
||||
}
|
||||
|
||||
type rec struct {
|
||||
b string
|
||||
}
|
||||
var a []rec
|
||||
for rows.Next() {
|
||||
var r rec
|
||||
if err := rows.Scan(&r.b); err != nil {
|
||||
tt.Fatal(err)
|
||||
}
|
||||
|
||||
a = append(a, r)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
tt.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := len(a), 1; g != e {
|
||||
tt.Fatal(g, e)
|
||||
}
|
||||
|
||||
if g, e := a[0].b, t1; g != e {
|
||||
tt.Fatal(g, e)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("regexp matches", func(tt *testing.T) {
|
||||
row := connection.QueryRowContext(ctx, "select 'seafood' regexp 'foo.*'")
|
||||
|
||||
var r int
|
||||
if err := row.Scan(&r); err != nil {
|
||||
tt.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := r, 1; g != e {
|
||||
tt.Fatal(g, e)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("regexp does not match", func(tt *testing.T) {
|
||||
row := connection.QueryRowContext(ctx, "select 'fruit' regexp 'foo.*'")
|
||||
|
||||
var r int
|
||||
if err := row.Scan(&r); err != nil {
|
||||
tt.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := r, 0; g != e {
|
||||
tt.Fatal(g, e)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("errors on bad regexp", func(tt *testing.T) {
|
||||
err := connection.QueryRowContext(ctx, "select 'seafood' regexp 'a(b'").Scan()
|
||||
if err == nil {
|
||||
tt.Fatal(errors.New("expected error, got none"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("errors on bad first argument", func(tt *testing.T) {
|
||||
err := connection.QueryRowContext(ctx, "SELECT 1 REGEXP 'a(b'").Scan()
|
||||
if err == nil {
|
||||
tt.Fatal(errors.New("expected error, got none"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("errors on bad second argument", func(tt *testing.T) {
|
||||
err := connection.QueryRowContext(ctx, "SELECT 'seafood' REGEXP 1").Scan()
|
||||
if err == nil {
|
||||
tt.Fatal(errors.New("expected error, got none"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBlob(t *testing.T) {
|
||||
dir, db := tempDB(t)
|
||||
|
||||
@@ -561,6 +821,11 @@ func TestConcurrentProcesses(t *testing.T) {
|
||||
t.Skip("skipping test in short mode")
|
||||
}
|
||||
|
||||
//TODO The current riscv64 board seems too slow for the hardcoded timeouts.
|
||||
if runtime.GOARCH == "riscv64" {
|
||||
t.Skip("skipping test")
|
||||
}
|
||||
|
||||
dir, err := ioutil.TempDir("", "sqlite-test-")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
Reference in New Issue
Block a user