diff --git a/all_test.go b/all_test.go index 3c15858..e631e11 100644 --- a/all_test.go +++ b/all_test.go @@ -857,6 +857,67 @@ func TestNoRows(t *testing.T) { } } +func TestColumns(t *testing.T) { + db, err := sql.Open("sqlite", "file::memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err := db.Exec("create table t1(a integer, b text, c blob)"); err != nil { + t.Fatal(err) + } + + if _, err := db.Exec("insert into t1 (a) values (1)"); err != nil { + t.Fatal(err) + } + + rows, err := db.Query("select * from t1") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + got, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + + want := []string{"a", "b", "c"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got columns %v, want %v", got, want) + } +} + +// https://gitlab.com/cznic/sqlite/-/issues/32 +func TestColumnsNoRows(t *testing.T) { + db, err := sql.Open("sqlite", "file::memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err := db.Exec("create table t1(a integer, b text, c blob)"); err != nil { + t.Fatal(err) + } + + rows, err := db.Query("select * from t1") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + got, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + + want := []string{"a", "b", "c"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got columns %v, want %v", got, want) + } +} + // https://gitlab.com/cznic/sqlite/-/issues/28 func TestIssue28(t *testing.T) { tempDir, err := ioutil.TempDir("", "") @@ -887,7 +948,7 @@ func TestIssue28(t *testing.T) { } // https://gitlab.com/cznic/sqlite/-/issues/30 -func TestIssue30(t *testing.T) { +func TestColumnTypes(t *testing.T) { tempDir, err := ioutil.TempDir("", "") if err != nil { t.Fatal(err) @@ -904,7 +965,7 @@ func TestIssue30(t *testing.T) { defer db.Close() - _, err = db.Query("CREATE TABLE IF NOT EXISTS `userinfo` (`uid` INTEGER PRIMARY KEY AUTOINCREMENT,`username` VARCHAR(64) NULL, `departname` VARCHAR(64) NULL, `created` DATE NULL);") + _, err = db.Exec("CREATE TABLE IF NOT EXISTS `userinfo` (`uid` INTEGER PRIMARY KEY AUTOINCREMENT,`username` VARCHAR(64) NULL, `departname` VARCHAR(64) NULL, `created` DATE NULL);") if err != nil { t.Fatal(err) } @@ -919,24 +980,27 @@ func TestIssue30(t *testing.T) { if err != nil { t.Fatal(err) } + defer rows2.Close() + + columnTypes, err := rows2.ColumnTypes() + if err != nil { + t.Fatal(err) + } - columnTypes, _ := rows2.ColumnTypes() var b strings.Builder - for rows2.Next() { - for index, value := range columnTypes { - precision, scale, precisionOk := value.DecimalSize() - length, lengthOk := value.Length() - nullable, nullableOk := value.Nullable() - fmt.Fprintf(&b, "Col %d: DatabaseTypeName %q, DecimalSize %v %v %v, Length %v %v, Name %q, Nullable %v %v, ScanType %q\n", - index, - value.DatabaseTypeName(), - precision, scale, precisionOk, - length, lengthOk, - value.Name(), - nullable, nullableOk, - value.ScanType(), - ) - } + for index, value := range columnTypes { + precision, scale, precisionOk := value.DecimalSize() + length, lengthOk := value.Length() + nullable, nullableOk := value.Nullable() + fmt.Fprintf(&b, "Col %d: DatabaseTypeName %q, DecimalSize %v %v %v, Length %v %v, Name %q, Nullable %v %v, ScanType %q\n", + index, + value.DatabaseTypeName(), + precision, scale, precisionOk, + length, lengthOk, + value.Name(), + nullable, nullableOk, + value.ScanType(), + ) } if err := rows2.Err(); err != nil { t.Fatal(err) @@ -952,6 +1016,69 @@ Col 3: DatabaseTypeName "DATE", DecimalSize 0 0 false, Length 922337203685477580 t.Log(b.String()) } +// https://gitlab.com/cznic/sqlite/-/issues/32 +func TestColumnTypesNoRows(t *testing.T) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + + defer func() { + os.RemoveAll(tempDir) + }() + + db, err := sql.Open("sqlite", filepath.Join(tempDir, "test.db")) + if err != nil { + t.Fatalf("test.db open fail: %v", err) + } + + defer db.Close() + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS `userinfo` (`uid` INTEGER PRIMARY KEY AUTOINCREMENT,`username` VARCHAR(64) NULL, `departname` VARCHAR(64) NULL, `created` DATE NULL);") + if err != nil { + t.Fatal(err) + } + + rows2, err := db.Query("SELECT * FROM userinfo") + if err != nil { + t.Fatal(err) + } + defer rows2.Close() + + columnTypes, err := rows2.ColumnTypes() + if err != nil { + t.Fatal(err) + } + + var b strings.Builder + for index, value := range columnTypes { + precision, scale, precisionOk := value.DecimalSize() + length, lengthOk := value.Length() + nullable, nullableOk := value.Nullable() + fmt.Fprintf(&b, "Col %d: DatabaseTypeName %q, DecimalSize %v %v %v, Length %v %v, Name %q, Nullable %v %v, ScanType %q\n", + index, + value.DatabaseTypeName(), + precision, scale, precisionOk, + length, lengthOk, + value.Name(), + nullable, nullableOk, + value.ScanType(), + ) + } + if err := rows2.Err(); err != nil { + t.Fatal(err) + } + + if g, e := b.String(), `Col 0: DatabaseTypeName "INTEGER", DecimalSize 0 0 false, Length 0 false, Name "uid", Nullable true true, ScanType %!q() +Col 1: DatabaseTypeName "VARCHAR(64)", DecimalSize 0 0 false, Length 0 false, Name "username", Nullable true true, ScanType %!q() +Col 2: DatabaseTypeName "VARCHAR(64)", DecimalSize 0 0 false, Length 0 false, Name "departname", Nullable true true, ScanType %!q() +Col 3: DatabaseTypeName "DATE", DecimalSize 0 0 false, Length 0 false, Name "created", Nullable true true, ScanType %!q() +`; g != e { + t.Fatalf("---- got\n%s\n----expected\n%s", g, e) + } + t.Log(b.String()) +} + // https://gitlab.com/cznic/sqlite/-/issues/35 func TestTime(t *testing.T) { types := []string{ diff --git a/sqlite.go b/sqlite.go index c8f3111..0b6e054 100644 --- a/sqlite.go +++ b/sqlite.go @@ -34,7 +34,6 @@ var ( _ driver.Queryer = (*conn)(nil) _ driver.Result = (*result)(nil) _ driver.Rows = (*rows)(nil) - _ driver.Rows = (*noRows)(nil) _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil) _ driver.RowsColumnTypeLength = (*rows)(nil) _ driver.RowsColumnTypeNullable = (*rows)(nil) @@ -188,14 +187,15 @@ type rows struct { } func newRows(c *conn, pstmt uintptr, allocs []uintptr, empty bool) (r *rows, err error) { + r = &rows{c: c, pstmt: pstmt, allocs: allocs, empty: empty} + defer func() { if err != nil { - c.finalize(pstmt) + r.Close() r = nil } }() - r = &rows{c: c, pstmt: pstmt, allocs: allocs} n, err := c.columnCount(pstmt) if err != nil { return nil, err @@ -630,7 +630,9 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro return nil case sqlite3.SQLITE_DONE: if r == nil { - r = &noRows{c: s.c, pstmt: pstmt, allocs: allocs} + if r, err = newRows(s.c, pstmt, allocs, true); err != nil { + return err + } pstmt = 0 return nil } @@ -663,23 +665,6 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro return r, err } -type noRows struct { - allocs []uintptr - c *conn - pstmt uintptr -} - -func (r *noRows) Columns() []string { return nil } -func (r *noRows) Next([]driver.Value) error { return io.EOF } - -func (r *noRows) Close() error { - for _, v := range r.allocs { - r.c.free(v) - } - r.allocs = nil - return r.c.finalize(r.pstmt) -} - type tx struct { c *conn }