mirror of
https://github.com/glebarez/go-sqlite.git
synced 2025-10-04 23:42:40 +08:00
driver: support ?NNN and $NNN parameters, add tests
Fix a bug around ensuring each index matches a corresponding argument. Support ?NNN parameters by checking for a match between NNN and Ordinal. Do the same for $NNN even though $NNN should technically require sql.Named. Updates https://gitlab.com/cznic/sqlite/-/issues/42.
This commit is contained in:
1
AUTHORS
1
AUTHORS
@@ -8,6 +8,7 @@
|
|||||||
#
|
#
|
||||||
# Please keep the list sorted.
|
# Please keep the list sorted.
|
||||||
|
|
||||||
|
Dan Peterson <danp@danp.net>
|
||||||
Davsk Ltd Co <skinner.david@gmail.com>
|
Davsk Ltd Co <skinner.david@gmail.com>
|
||||||
Jaap Aarts <jaap.aarts1@gmail.com>
|
Jaap Aarts <jaap.aarts1@gmail.com>
|
||||||
Jan Mercl <0xjnml@gmail.com>
|
Jan Mercl <0xjnml@gmail.com>
|
||||||
|
@@ -7,6 +7,7 @@
|
|||||||
# Please keep the list sorted.
|
# Please keep the list sorted.
|
||||||
|
|
||||||
Alexander Menzhinsky <amenzhinsky@gmail.com>
|
Alexander Menzhinsky <amenzhinsky@gmail.com>
|
||||||
|
Dan Peterson <danp@danp.net>
|
||||||
David Skinner <skinner.david@gmail.com>
|
David Skinner <skinner.david@gmail.com>
|
||||||
Jaap Aarts <jaap.aarts1@gmail.com>
|
Jaap Aarts <jaap.aarts1@gmail.com>
|
||||||
Jan Mercl <0xjnml@gmail.com>
|
Jan Mercl <0xjnml@gmail.com>
|
||||||
|
187
all_test.go
187
all_test.go
@@ -16,6 +16,7 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -1047,3 +1048,189 @@ func TestTime(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://sqlite.org/lang_expr.html#varparam
|
||||||
|
// https://gitlab.com/cznic/sqlite/-/issues/42
|
||||||
|
func TestBinding(t *testing.T) {
|
||||||
|
t.Run("DB", func(t *testing.T) {
|
||||||
|
testBinding(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
|
||||||
|
return db.QueryRow(query, args...), func() {}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Prepare", func(t *testing.T) {
|
||||||
|
testBinding(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
|
||||||
|
stmt, err := db.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return stmt.QueryRow(args...), func() { stmt.Close() }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testBinding(t *testing.T, query func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func())) {
|
||||||
|
db, err := sql.Open(driverName, "file::memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
q string
|
||||||
|
in []interface{}
|
||||||
|
w []int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
q: "?, ?, ?",
|
||||||
|
in: []interface{}{1, 2, 3},
|
||||||
|
w: []int{1, 2, 3},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "?1, ?2, ?3",
|
||||||
|
in: []interface{}{1, 2, 3},
|
||||||
|
w: []int{1, 2, 3},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "?1, ?, ?3",
|
||||||
|
in: []interface{}{1, 2, 3},
|
||||||
|
w: []int{1, 2, 3},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "?3, ?2, ?1",
|
||||||
|
in: []interface{}{1, 2, 3},
|
||||||
|
w: []int{3, 2, 1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "?1, ?1, ?2",
|
||||||
|
in: []interface{}{1, 2},
|
||||||
|
w: []int{1, 1, 2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: ":one, :two, :three",
|
||||||
|
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
|
||||||
|
w: []int{1, 2, 3},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: ":one, :one, :two",
|
||||||
|
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
|
||||||
|
w: []int{1, 1, 2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "@one, @two, @three",
|
||||||
|
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
|
||||||
|
w: []int{1, 2, 3},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "@one, @one, @two",
|
||||||
|
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
|
||||||
|
w: []int{1, 1, 2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "$one, $two, $three",
|
||||||
|
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
|
||||||
|
w: []int{1, 2, 3},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// A common usage that should technically require sql.Named but
|
||||||
|
// does not.
|
||||||
|
q: "$1, $2, $3",
|
||||||
|
in: []interface{}{1, 2, 3},
|
||||||
|
w: []int{1, 2, 3},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "$one, $one, $two",
|
||||||
|
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
|
||||||
|
w: []int{1, 1, 2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: ":one, @one, $one",
|
||||||
|
in: []interface{}{sql.Named("one", 1)},
|
||||||
|
w: []int{1, 1, 1},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
got := make([]int, len(tc.w))
|
||||||
|
ptrs := make([]interface{}, len(got))
|
||||||
|
for i := range got {
|
||||||
|
ptrs[i] = &got[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
row, cleanup := query(db, "select "+tc.q, tc.in...)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
if err := row.Scan(ptrs...); err != nil {
|
||||||
|
t.Errorf("query(%q, %+v) = %s", tc.q, tc.in, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, tc.w) {
|
||||||
|
t.Errorf("query(%q, %+v) = %#+v, want %#+v", tc.q, tc.in, got, tc.w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindingError(t *testing.T) {
|
||||||
|
t.Run("DB", func(t *testing.T) {
|
||||||
|
testBindingError(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
|
||||||
|
return db.QueryRow(query, args...), func() {}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Prepare", func(t *testing.T) {
|
||||||
|
testBindingError(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
|
||||||
|
stmt, err := db.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return stmt.QueryRow(args...), func() { stmt.Close() }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testBindingError(t *testing.T, query func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func())) {
|
||||||
|
db, err := sql.Open(driverName, "file::memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
q string
|
||||||
|
in []interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
q: "?",
|
||||||
|
in: []interface{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "?500, ?",
|
||||||
|
in: []interface{}{1, 2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: ":one",
|
||||||
|
in: []interface{}{1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "@one",
|
||||||
|
in: []interface{}{1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
q: "$one",
|
||||||
|
in: []interface{}{1},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
got := make([]int, 2)
|
||||||
|
ptrs := make([]interface{}, len(got))
|
||||||
|
for i := range got {
|
||||||
|
ptrs[i] = &got[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
row, cleanup := query(db, "select "+tc.q, tc.in...)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
err := row.Scan(ptrs...)
|
||||||
|
if err == nil || (!strings.Contains(err.Error(), "missing argument with index") && !strings.Contains(err.Error(), "missing named argument")) {
|
||||||
|
t.Errorf("query(%q, %+v) unexpected error %+v", tc.q, tc.in, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
21
sqlite.go
21
sqlite.go
@@ -15,6 +15,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -495,7 +496,6 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = func() (err error) {
|
err = func() (err error) {
|
||||||
|
|
||||||
n, err := s.c.bindParameterCount(pstmt)
|
n, err := s.c.bindParameterCount(pstmt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -891,38 +891,51 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
for i := 1; i <= n; i++ {
|
for i := 1; i <= n; i++ {
|
||||||
var p uintptr
|
|
||||||
name, err := c.bindParameterName(pstmt, i)
|
name, err := c.bindParameterName(pstmt, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return allocs, err
|
return allocs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var found bool
|
||||||
var v driver.NamedValue
|
var v driver.NamedValue
|
||||||
for _, v = range args {
|
for _, v = range args {
|
||||||
if name != "" {
|
if name != "" {
|
||||||
|
// For ?NNN and $NNN params, match if NNN == v.Ordinal.
|
||||||
|
//
|
||||||
|
// Supporting this for $NNN is a special case that makes eg
|
||||||
|
// `select $1, $2, $3 ...` work without needing to use
|
||||||
|
// sql.Named.
|
||||||
|
if (name[0] == '?' || name[0] == '$') && name[1:] == strconv.Itoa(v.Ordinal) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
// sqlite supports '$', '@' and ':' prefixes for string
|
// sqlite supports '$', '@' and ':' prefixes for string
|
||||||
// identifiers and '?' for numeric, so we cannot
|
// identifiers and '?' for numeric, so we cannot
|
||||||
// combine different prefixes with the same name
|
// combine different prefixes with the same name
|
||||||
// because `database/sql` requires variable names
|
// because `database/sql` requires variable names
|
||||||
// to start with a letter
|
// to start with a letter
|
||||||
if name[1:] == v.Name[:] {
|
if name[1:] == v.Name[:] {
|
||||||
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if v.Ordinal == i {
|
if v.Ordinal == i {
|
||||||
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.Ordinal == 0 {
|
if !found {
|
||||||
if name != "" {
|
if name != "" {
|
||||||
return allocs, fmt.Errorf("missing named argument %q", name[1:])
|
return allocs, fmt.Errorf("missing named argument %q", name[1:])
|
||||||
}
|
}
|
||||||
|
|
||||||
return allocs, fmt.Errorf("missing argument with %d index", i)
|
return allocs, fmt.Errorf("missing argument with index %d", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var p uintptr
|
||||||
switch x := v.Value.(type) {
|
switch x := v.Value.(type) {
|
||||||
case int64:
|
case int64:
|
||||||
if err := c.bindInt64(pstmt, i, x); err != nil {
|
if err := c.bindInt64(pstmt, i, x); err != nil {
|
||||||
|
Reference in New Issue
Block a user