diff --git a/AUTHORS b/AUTHORS index cce2dd5..8e2fd5f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -8,6 +8,7 @@ # # Please keep the list sorted. +Dan Peterson Davsk Ltd Co Jaap Aarts Jan Mercl <0xjnml@gmail.com> diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 3165b32..5ff000b 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -7,6 +7,7 @@ # Please keep the list sorted. Alexander Menzhinsky +Dan Peterson David Skinner Jaap Aarts Jan Mercl <0xjnml@gmail.com> diff --git a/all_test.go b/all_test.go index b451533..688fb1b 100644 --- a/all_test.go +++ b/all_test.go @@ -16,6 +16,7 @@ import ( "os/exec" "path" "path/filepath" + "reflect" "runtime" "runtime/debug" "strconv" @@ -1047,3 +1048,189 @@ func TestTime(t *testing.T) { } } } + +// https://sqlite.org/lang_expr.html#varparam +// https://gitlab.com/cznic/sqlite/-/issues/42 +func TestBinding(t *testing.T) { + t.Run("DB", func(t *testing.T) { + testBinding(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) { + return db.QueryRow(query, args...), func() {} + }) + }) + + t.Run("Prepare", func(t *testing.T) { + testBinding(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) { + stmt, err := db.Prepare(query) + if err != nil { + t.Fatal(err) + } + return stmt.QueryRow(args...), func() { stmt.Close() } + }) + }) +} + +func testBinding(t *testing.T, query func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func())) { + db, err := sql.Open(driverName, "file::memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + for _, tc := range []struct { + q string + in []interface{} + w []int + }{ + { + q: "?, ?, ?", + in: []interface{}{1, 2, 3}, + w: []int{1, 2, 3}, + }, + { + q: "?1, ?2, ?3", + in: []interface{}{1, 2, 3}, + w: []int{1, 2, 3}, + }, + { + q: "?1, ?, ?3", + in: []interface{}{1, 2, 3}, + w: []int{1, 2, 3}, + }, + { + q: "?3, ?2, ?1", + in: []interface{}{1, 2, 3}, + w: []int{3, 2, 1}, + }, + { + q: "?1, ?1, ?2", + in: []interface{}{1, 2}, + w: []int{1, 1, 2}, + }, + { + q: ":one, :two, :three", + in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)}, + w: []int{1, 2, 3}, + }, + { + q: ":one, :one, :two", + in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)}, + w: []int{1, 1, 2}, + }, + { + q: "@one, @two, @three", + in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)}, + w: []int{1, 2, 3}, + }, + { + q: "@one, @one, @two", + in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)}, + w: []int{1, 1, 2}, + }, + { + q: "$one, $two, $three", + in: []interface{}{sql.Named("one", 1), sql.Named("two", 2), sql.Named("three", 3)}, + w: []int{1, 2, 3}, + }, + { + // A common usage that should technically require sql.Named but + // does not. + q: "$1, $2, $3", + in: []interface{}{1, 2, 3}, + w: []int{1, 2, 3}, + }, + { + q: "$one, $one, $two", + in: []interface{}{sql.Named("one", 1), sql.Named("two", 2)}, + w: []int{1, 1, 2}, + }, + { + q: ":one, @one, $one", + in: []interface{}{sql.Named("one", 1)}, + w: []int{1, 1, 1}, + }, + } { + got := make([]int, len(tc.w)) + ptrs := make([]interface{}, len(got)) + for i := range got { + ptrs[i] = &got[i] + } + + row, cleanup := query(db, "select "+tc.q, tc.in...) + defer cleanup() + + if err := row.Scan(ptrs...); err != nil { + t.Errorf("query(%q, %+v) = %s", tc.q, tc.in, err) + continue + } + + if !reflect.DeepEqual(got, tc.w) { + t.Errorf("query(%q, %+v) = %#+v, want %#+v", tc.q, tc.in, got, tc.w) + } + } +} + +func TestBindingError(t *testing.T) { + t.Run("DB", func(t *testing.T) { + testBindingError(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) { + return db.QueryRow(query, args...), func() {} + }) + }) + + t.Run("Prepare", func(t *testing.T) { + testBindingError(t, func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func()) { + stmt, err := db.Prepare(query) + if err != nil { + t.Fatal(err) + } + return stmt.QueryRow(args...), func() { stmt.Close() } + }) + }) +} + +func testBindingError(t *testing.T, query func(db *sql.DB, query string, args ...interface{}) (*sql.Row, func())) { + db, err := sql.Open(driverName, "file::memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + for _, tc := range []struct { + q string + in []interface{} + }{ + { + q: "?", + in: []interface{}{}, + }, + { + q: "?500, ?", + in: []interface{}{1, 2}, + }, + { + q: ":one", + in: []interface{}{1}, + }, + { + q: "@one", + in: []interface{}{1}, + }, + { + q: "$one", + in: []interface{}{1}, + }, + } { + got := make([]int, 2) + ptrs := make([]interface{}, len(got)) + for i := range got { + ptrs[i] = &got[i] + } + + row, cleanup := query(db, "select "+tc.q, tc.in...) + defer cleanup() + + err := row.Scan(ptrs...) + if err == nil || (!strings.Contains(err.Error(), "missing argument with index") && !strings.Contains(err.Error(), "missing named argument")) { + t.Errorf("query(%q, %+v) unexpected error %+v", tc.q, tc.in, err) + } + } +} diff --git a/sqlite.go b/sqlite.go index 9adc10b..df5340d 100644 --- a/sqlite.go +++ b/sqlite.go @@ -15,6 +15,7 @@ import ( "io" "math" "reflect" + "strconv" "strings" "time" "unsafe" @@ -495,7 +496,6 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res continue } err = func() (err error) { - n, err := s.c.bindParameterCount(pstmt) if err != nil { return err @@ -891,38 +891,51 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui }() for i := 1; i <= n; i++ { - var p uintptr name, err := c.bindParameterName(pstmt, i) if err != nil { return allocs, err } + var found bool var v driver.NamedValue for _, v = range args { if name != "" { + // For ?NNN and $NNN params, match if NNN == v.Ordinal. + // + // Supporting this for $NNN is a special case that makes eg + // `select $1, $2, $3 ...` work without needing to use + // sql.Named. + if (name[0] == '?' || name[0] == '$') && name[1:] == strconv.Itoa(v.Ordinal) { + found = true + break + } + // sqlite supports '$', '@' and ':' prefixes for string // identifiers and '?' for numeric, so we cannot // combine different prefixes with the same name // because `database/sql` requires variable names // to start with a letter if name[1:] == v.Name[:] { + found = true break } } else { if v.Ordinal == i { + found = true break } } } - if v.Ordinal == 0 { + if !found { if name != "" { return allocs, fmt.Errorf("missing named argument %q", name[1:]) } - return allocs, fmt.Errorf("missing argument with %d index", i) + return allocs, fmt.Errorf("missing argument with index %d", i) } + var p uintptr switch x := v.Value.(type) { case int64: if err := c.bindInt64(pstmt, i, x); err != nil {