mirror of
https://github.com/glebarez/go-sqlite.git
synced 2025-10-04 15:32:46 +08:00
driver: support ?NNN and $NNN parameters, add tests
Fix a bug around ensuring each index matches a corresponding argument. Support ?NNN parameters by checking for a match between NNN and Ordinal. Do the same for $NNN even though $NNN should technically require sql.Named. Updates https://gitlab.com/cznic/sqlite/-/issues/42.
This commit is contained in:
1
AUTHORS
1
AUTHORS
@@ -8,6 +8,7 @@
|
||||
#
|
||||
# Please keep the list sorted.
|
||||
|
||||
Dan Peterson <danp@danp.net>
|
||||
Davsk Ltd Co <skinner.david@gmail.com>
|
||||
Jaap Aarts <jaap.aarts1@gmail.com>
|
||||
Jan Mercl <0xjnml@gmail.com>
|
||||
|
@@ -7,6 +7,7 @@
|
||||
# Please keep the list sorted.
|
||||
|
||||
Alexander Menzhinsky <amenzhinsky@gmail.com>
|
||||
Dan Peterson <danp@danp.net>
|
||||
David Skinner <skinner.david@gmail.com>
|
||||
Jaap Aarts <jaap.aarts1@gmail.com>
|
||||
Jan Mercl <0xjnml@gmail.com>
|
||||
|
187
all_test.go
187
all_test.go
@@ -16,6 +16,7 @@ import (
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
@@ -1047,3 +1048,189 @@ func TestTime(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://sqlite.org/lang_expr.html#varparam
|
||||
// https://gitlab.com/cznic/sqlite/-/issues/42
|
||||
func TestBinding(t *testing.T) {
|
||||
t.Run("DB", func(t *testing.T) {
|
||||
testBinding(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
|
||||
return db.QueryRow(query, args...), func() {}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Prepare", func(t *testing.T) {
|
||||
testBinding(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return stmt.QueryRow(args...), func() { stmt.Close() }
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testBinding(t *testing.T, query func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func())) {
|
||||
db, err := sql.Open(driverName, "file::memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
for _, tc := range []struct {
|
||||
q string
|
||||
in []interface{}
|
||||
w []int
|
||||
}{
|
||||
{
|
||||
q: "?, ?, ?",
|
||||
in: []interface{}{1, 2, 3},
|
||||
w: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: "?1, ?2, ?3",
|
||||
in: []interface{}{1, 2, 3},
|
||||
w: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: "?1, ?, ?3",
|
||||
in: []interface{}{1, 2, 3},
|
||||
w: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: "?3, ?2, ?1",
|
||||
in: []interface{}{1, 2, 3},
|
||||
w: []int{3, 2, 1},
|
||||
},
|
||||
{
|
||||
q: "?1, ?1, ?2",
|
||||
in: []interface{}{1, 2},
|
||||
w: []int{1, 1, 2},
|
||||
},
|
||||
{
|
||||
q: ":one, :two, :three",
|
||||
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
|
||||
w: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: ":one, :one, :two",
|
||||
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
|
||||
w: []int{1, 1, 2},
|
||||
},
|
||||
{
|
||||
q: "@one, @two, @three",
|
||||
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
|
||||
w: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: "@one, @one, @two",
|
||||
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
|
||||
w: []int{1, 1, 2},
|
||||
},
|
||||
{
|
||||
q: "$one, $two, $three",
|
||||
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
|
||||
w: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
// A common usage that should technically require sql.Named but
|
||||
// does not.
|
||||
q: "$1, $2, $3",
|
||||
in: []interface{}{1, 2, 3},
|
||||
w: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: "$one, $one, $two",
|
||||
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
|
||||
w: []int{1, 1, 2},
|
||||
},
|
||||
{
|
||||
q: ":one, @one, $one",
|
||||
in: []interface{}{sql.Named("one", 1)},
|
||||
w: []int{1, 1, 1},
|
||||
},
|
||||
} {
|
||||
got := make([]int, len(tc.w))
|
||||
ptrs := make([]interface{}, len(got))
|
||||
for i := range got {
|
||||
ptrs[i] = &got[i]
|
||||
}
|
||||
|
||||
row, cleanup := query(db, "select "+tc.q, tc.in...)
|
||||
defer cleanup()
|
||||
|
||||
if err := row.Scan(ptrs...); err != nil {
|
||||
t.Errorf("query(%q, %+v) = %s", tc.q, tc.in, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tc.w) {
|
||||
t.Errorf("query(%q, %+v) = %#+v, want %#+v", tc.q, tc.in, got, tc.w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindingError(t *testing.T) {
|
||||
t.Run("DB", func(t *testing.T) {
|
||||
testBindingError(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
|
||||
return db.QueryRow(query, args...), func() {}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Prepare", func(t *testing.T) {
|
||||
testBindingError(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return stmt.QueryRow(args...), func() { stmt.Close() }
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testBindingError(t *testing.T, query func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func())) {
|
||||
db, err := sql.Open(driverName, "file::memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
for _, tc := range []struct {
|
||||
q string
|
||||
in []interface{}
|
||||
}{
|
||||
{
|
||||
q: "?",
|
||||
in: []interface{}{},
|
||||
},
|
||||
{
|
||||
q: "?500, ?",
|
||||
in: []interface{}{1, 2},
|
||||
},
|
||||
{
|
||||
q: ":one",
|
||||
in: []interface{}{1},
|
||||
},
|
||||
{
|
||||
q: "@one",
|
||||
in: []interface{}{1},
|
||||
},
|
||||
{
|
||||
q: "$one",
|
||||
in: []interface{}{1},
|
||||
},
|
||||
} {
|
||||
got := make([]int, 2)
|
||||
ptrs := make([]interface{}, len(got))
|
||||
for i := range got {
|
||||
ptrs[i] = &got[i]
|
||||
}
|
||||
|
||||
row, cleanup := query(db, "select "+tc.q, tc.in...)
|
||||
defer cleanup()
|
||||
|
||||
err := row.Scan(ptrs...)
|
||||
if err == nil || (!strings.Contains(err.Error(), "missing argument with index") && !strings.Contains(err.Error(), "missing named argument")) {
|
||||
t.Errorf("query(%q, %+v) unexpected error %+v", tc.q, tc.in, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
21
sqlite.go
21
sqlite.go
@@ -15,6 +15,7 @@ import (
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
@@ -495,7 +496,6 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res
|
||||
continue
|
||||
}
|
||||
err = func() (err error) {
|
||||
|
||||
n, err := s.c.bindParameterCount(pstmt)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -891,38 +891,51 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui
|
||||
}()
|
||||
|
||||
for i := 1; i <= n; i++ {
|
||||
var p uintptr
|
||||
name, err := c.bindParameterName(pstmt, i)
|
||||
if err != nil {
|
||||
return allocs, err
|
||||
}
|
||||
|
||||
var found bool
|
||||
var v driver.NamedValue
|
||||
for _, v = range args {
|
||||
if name != "" {
|
||||
// For ?NNN and $NNN params, match if NNN == v.Ordinal.
|
||||
//
|
||||
// Supporting this for $NNN is a special case that makes eg
|
||||
// `select $1, $2, $3 ...` work without needing to use
|
||||
// sql.Named.
|
||||
if (name[0] == '?' || name[0] == '$') && name[1:] == strconv.Itoa(v.Ordinal) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
// sqlite supports '$', '@' and ':' prefixes for string
|
||||
// identifiers and '?' for numeric, so we cannot
|
||||
// combine different prefixes with the same name
|
||||
// because `database/sql` requires variable names
|
||||
// to start with a letter
|
||||
if name[1:] == v.Name[:] {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if v.Ordinal == i {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if v.Ordinal == 0 {
|
||||
if !found {
|
||||
if name != "" {
|
||||
return allocs, fmt.Errorf("missing named argument %q", name[1:])
|
||||
}
|
||||
|
||||
return allocs, fmt.Errorf("missing argument with %d index", i)
|
||||
return allocs, fmt.Errorf("missing argument with index %d", i)
|
||||
}
|
||||
|
||||
var p uintptr
|
||||
switch x := v.Value.(type) {
|
||||
case int64:
|
||||
if err := c.bindInt64(pstmt, i, x); err != nil {
|
||||
|
Reference in New Issue
Block a user