fix forgotten TODO

This commit is contained in:
Jan Mercl
2020-01-01 12:05:39 +01:00
parent ef38ac9c3b
commit 141c3f22b7
2 changed files with 63 additions and 8 deletions

View File

@@ -21,6 +21,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -52,6 +53,23 @@ func dbg(s string, va ...interface{}) {
os.Stderr.Sync() os.Stderr.Sync()
} }
var traceLevel int32
func trace() func() {
return func() {}
n := atomic.AddInt32(&traceLevel, 1)
pc, file, line, _ := runtime.Caller(1)
s := strings.Repeat("· ", int(n)-1)
fn := runtime.FuncForPC(pc)
fmt.Fprintf(os.Stderr, "%s# trace %s:%d:%s: in\n", s, path.Base(file), line, fn.Name())
os.Stderr.Sync()
return func() {
atomic.AddInt32(&traceLevel, -1)
fmt.Fprintf(os.Stderr, "%s# trace %s:%d:%s: out\n", s, path.Base(file), line, fn.Name())
os.Stderr.Sync()
}
}
func TODO(...interface{}) string { //TODOOK func TODO(...interface{}) string { //TODOOK
_, fn, fl, _ := runtime.Caller(1) _, fn, fl, _ := runtime.Caller(1)
return fmt.Sprintf("# TODO: %s:%d:\n", path.Base(fn), fl) //TODOOK return fmt.Sprintf("# TODO: %s:%d:\n", path.Base(fn), fl) //TODOOK
@@ -62,7 +80,7 @@ func stack() string { return string(debug.Stack()) }
func use(...interface{}) {} func use(...interface{}) {}
func init() { func init() {
use(caller, dbg, TODO) //TODOOK use(caller, dbg, TODO, trace) //TODOOK
} }
// ============================================================================ // ============================================================================
@@ -377,6 +395,8 @@ func TestConcurrentGoroutines(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer db.Close()
tx, err := db.BeginTx(context.Background(), nil) tx, err := db.BeginTx(context.Background(), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -681,6 +701,8 @@ func TestIssue20(t *testing.T) {
t.Fatalf("foo.db open fail: %v", err) t.Fatalf("foo.db open fail: %v", err)
} }
defer db.Close()
mustExec(t, db, "CREATE TABLE "+TablePrefix+"t (count INT)") mustExec(t, db, "CREATE TABLE "+TablePrefix+"t (count INT)")
sel, err := db.PrepareContext(context.Background(), "SELECT count FROM "+TablePrefix+"t ORDER BY count DESC") sel, err := db.PrepareContext(context.Background(), "SELECT count FROM "+TablePrefix+"t ORDER BY count DESC")
if err != nil { if err != nil {
@@ -723,3 +745,28 @@ func TestIssue20(t *testing.T) {
<-ch <-ch
} }
} }
func TestNoRows(t *testing.T) {
tempDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempDir)
db, err := sql.Open("sqlite", filepath.Join(tempDir, "foo.db"))
if err != nil {
t.Fatalf("foo.db open fail: %v", err)
}
defer db.Close()
stmt, err := db.Prepare("create table t(i);")
if err != nil {
t.Fatal(err)
}
if _, err := stmt.Query(); err != nil {
t.Fatal(err)
}
}

View File

@@ -171,9 +171,10 @@ type rows struct {
pstmt crt.Intptr pstmt crt.Intptr
doStep bool doStep bool
empty bool
} }
func newRows(c *conn, pstmt crt.Intptr, allocs []crt.Intptr) (r *rows, err error) { func newRows(c *conn, pstmt crt.Intptr, allocs []crt.Intptr, empty bool) (r *rows, err error) {
defer func() { defer func() {
if err != nil { if err != nil {
c.finalize(pstmt) c.finalize(pstmt)
@@ -218,6 +219,10 @@ func (r *rows) Columns() (c []string) {
// //
// Next should return io.EOF when there are no more rows. // Next should return io.EOF when there are no more rows.
func (r *rows) Next(dest []driver.Value) (err error) { func (r *rows) Next(dest []driver.Value) (err error) {
if r.empty {
return io.EOF
}
rc := bin.DSQLITE_ROW rc := bin.DSQLITE_ROW
if r.doStep { if r.doStep {
if rc, err = r.c.step(r.pstmt); err != nil { if rc, err = r.c.step(r.pstmt); err != nil {
@@ -480,7 +485,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
switch rc & 0xff { switch rc & 0xff {
case bin.DSQLITE_ROW: case bin.DSQLITE_ROW:
if r, err = newRows(s.c, pstmt, allocs); err != nil { if r, err = newRows(s.c, pstmt, allocs, false); err != nil {
return err return err
} }
@@ -492,16 +497,19 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
return s.c.errstr(int32(rc)) return s.c.errstr(int32(rc))
} }
if *(*byte)(unsafe.Pointer(uintptr(psql))) == 0 {
if r, err = newRows(s.c, pstmt, allocs, true); err != nil {
return err
}
pstmt = 0
}
return nil return nil
}(); err != nil { }(); err != nil {
return nil, err return nil, err
} }
} }
if r != nil { return r, err
return r, nil
}
panic("TODO")
} }
type tx struct { type tx struct {