driver: support ?NNN and $NNN parameters, add tests

Fix a bug around ensuring each index matches a corresponding
argument.

Support ?NNN parameters by checking for a match between NNN and
Ordinal. Do the same for $NNN even though $NNN should technically
require sql.Named.

Updates https://gitlab.com/cznic/sqlite/-/issues/42.
This commit is contained in:
Dan Peterson
2021-01-17 12:57:51 -04:00
parent c1357b87cd
commit 2d062fa148
4 changed files with 206 additions and 4 deletions

View File

@@ -8,6 +8,7 @@
# #
# Please keep the list sorted. # Please keep the list sorted.
Dan Peterson <danp@danp.net>
Davsk Ltd Co <skinner.david@gmail.com> Davsk Ltd Co <skinner.david@gmail.com>
Jaap Aarts <jaap.aarts1@gmail.com> Jaap Aarts <jaap.aarts1@gmail.com>
Jan Mercl <0xjnml@gmail.com> Jan Mercl <0xjnml@gmail.com>

View File

@@ -7,6 +7,7 @@
# Please keep the list sorted. # Please keep the list sorted.
Alexander Menzhinsky <amenzhinsky@gmail.com> Alexander Menzhinsky <amenzhinsky@gmail.com>
Dan Peterson <danp@danp.net>
David Skinner <skinner.david@gmail.com> David Skinner <skinner.david@gmail.com>
Jaap Aarts <jaap.aarts1@gmail.com> Jaap Aarts <jaap.aarts1@gmail.com>
Jan Mercl <0xjnml@gmail.com> Jan Mercl <0xjnml@gmail.com>

View File

@@ -16,6 +16,7 @@ import (
"os/exec" "os/exec"
"path" "path"
"path/filepath" "path/filepath"
"reflect"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strconv" "strconv"
@@ -1047,3 +1048,189 @@ func TestTime(t *testing.T) {
} }
} }
} }
// https://sqlite.org/lang_expr.html#varparam
// https://gitlab.com/cznic/sqlite/-/issues/42
func TestBinding(t *testing.T) {
t.Run("DB", func(t *testing.T) {
testBinding(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
return db.QueryRow(query, args...), func() {}
})
})
t.Run("Prepare", func(t *testing.T) {
testBinding(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
stmt, err := db.Prepare(query)
if err != nil {
t.Fatal(err)
}
return stmt.QueryRow(args...), func() { stmt.Close() }
})
})
}
func testBinding(t *testing.T, query func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func())) {
db, err := sql.Open(driverName, "file::memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
for _, tc := range []struct {
q string
in []interface{}
w []int
}{
{
q: "?, ?, ?",
in: []interface{}{1, 2, 3},
w: []int{1, 2, 3},
},
{
q: "?1, ?2, ?3",
in: []interface{}{1, 2, 3},
w: []int{1, 2, 3},
},
{
q: "?1, ?, ?3",
in: []interface{}{1, 2, 3},
w: []int{1, 2, 3},
},
{
q: "?3, ?2, ?1",
in: []interface{}{1, 2, 3},
w: []int{3, 2, 1},
},
{
q: "?1, ?1, ?2",
in: []interface{}{1, 2},
w: []int{1, 1, 2},
},
{
q: ":one, :two, :three",
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
w: []int{1, 2, 3},
},
{
q: ":one, :one, :two",
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
w: []int{1, 1, 2},
},
{
q: "@one, @two, @three",
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
w: []int{1, 2, 3},
},
{
q: "@one, @one, @two",
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
w: []int{1, 1, 2},
},
{
q: "$one, $two, $three",
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)},
w: []int{1, 2, 3},
},
{
// A common usage that should technically require sql.Named but
// does not.
q: "$1, $2, $3",
in: []interface{}{1, 2, 3},
w: []int{1, 2, 3},
},
{
q: "$one, $one, $two",
in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)},
w: []int{1, 1, 2},
},
{
q: ":one, @one, $one",
in: []interface{}{sql.Named("one", 1)},
w: []int{1, 1, 1},
},
} {
got := make([]int, len(tc.w))
ptrs := make([]interface{}, len(got))
for i := range got {
ptrs[i] = &got[i]
}
row, cleanup := query(db, "select "+tc.q, tc.in...)
defer cleanup()
if err := row.Scan(ptrs...); err != nil {
t.Errorf("query(%q, %+v) = %s", tc.q, tc.in, err)
continue
}
if !reflect.DeepEqual(got, tc.w) {
t.Errorf("query(%q, %+v) = %#+v, want %#+v", tc.q, tc.in, got, tc.w)
}
}
}
func TestBindingError(t *testing.T) {
t.Run("DB", func(t *testing.T) {
testBindingError(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
return db.QueryRow(query, args...), func() {}
})
})
t.Run("Prepare", func(t *testing.T) {
testBindingError(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) {
stmt, err := db.Prepare(query)
if err != nil {
t.Fatal(err)
}
return stmt.QueryRow(args...), func() { stmt.Close() }
})
})
}
func testBindingError(t *testing.T, query func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func())) {
db, err := sql.Open(driverName, "file::memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
for _, tc := range []struct {
q string
in []interface{}
}{
{
q: "?",
in: []interface{}{},
},
{
q: "?500, ?",
in: []interface{}{1, 2},
},
{
q: ":one",
in: []interface{}{1},
},
{
q: "@one",
in: []interface{}{1},
},
{
q: "$one",
in: []interface{}{1},
},
} {
got := make([]int, 2)
ptrs := make([]interface{}, len(got))
for i := range got {
ptrs[i] = &got[i]
}
row, cleanup := query(db, "select "+tc.q, tc.in...)
defer cleanup()
err := row.Scan(ptrs...)
if err == nil || (!strings.Contains(err.Error(), "missing argument with index") && !strings.Contains(err.Error(), "missing named argument")) {
t.Errorf("query(%q, %+v) unexpected error %+v", tc.q, tc.in, err)
}
}
}

View File

@@ -15,6 +15,7 @@ import (
"io" "io"
"math" "math"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time" "time"
"unsafe" "unsafe"
@@ -495,7 +496,6 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res
continue continue
} }
err = func() (err error) { err = func() (err error) {
n, err := s.c.bindParameterCount(pstmt) n, err := s.c.bindParameterCount(pstmt)
if err != nil { if err != nil {
return err return err
@@ -891,38 +891,51 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui
}() }()
for i := 1; i <= n; i++ { for i := 1; i <= n; i++ {
var p uintptr
name, err := c.bindParameterName(pstmt, i) name, err := c.bindParameterName(pstmt, i)
if err != nil { if err != nil {
return allocs, err return allocs, err
} }
var found bool
var v driver.NamedValue var v driver.NamedValue
for _, v = range args { for _, v = range args {
if name != "" { if name != "" {
// For ?NNN and $NNN params, match if NNN == v.Ordinal.
//
// Supporting this for $NNN is a special case that makes eg
// `select $1, $2, $3 ...` work without needing to use
// sql.Named.
if (name[0] == '?' || name[0] == '$') && name[1:] == strconv.Itoa(v.Ordinal) {
found = true
break
}
// sqlite supports '$', '@' and ':' prefixes for string // sqlite supports '$', '@' and ':' prefixes for string
// identifiers and '?' for numeric, so we cannot // identifiers and '?' for numeric, so we cannot
// combine different prefixes with the same name // combine different prefixes with the same name
// because `database/sql` requires variable names // because `database/sql` requires variable names
// to start with a letter // to start with a letter
if name[1:] == v.Name[:] { if name[1:] == v.Name[:] {
found = true
break break
} }
} else { } else {
if v.Ordinal == i { if v.Ordinal == i {
found = true
break break
} }
} }
} }
if v.Ordinal == 0 { if !found {
if name != "" { if name != "" {
return allocs, fmt.Errorf("missing named argument %q", name[1:]) return allocs, fmt.Errorf("missing named argument %q", name[1:])
} }
return allocs, fmt.Errorf("missing argument with %d index", i) return allocs, fmt.Errorf("missing argument with index %d", i)
} }
var p uintptr
switch x := v.Value.(type) { switch x := v.Value.(type) {
case int64: case int64:
if err := c.bindInt64(pstmt, i, x); err != nil { if err := c.bindInt64(pstmt, i, x); err != nil {