diff --git a/all_test.go b/all_test.go index fdadde5..cf0aff8 100644 --- a/all_test.go +++ b/all_test.go @@ -2204,3 +2204,86 @@ func TestBeginMode(t *testing.T) { } } } + +// https://gitlab.com/cznic/sqlite/-/issues/94 +func TestCancelRace(t *testing.T) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + + defer func() { + os.RemoveAll(tempDir) + }() + + db, err := sql.Open("sqlite", filepath.Join(tempDir, "testcancelrace.sqlite")) + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + tests := []struct { + name string + f func(context.Context, *sql.DB) error + }{ + { + "db.ExecContext", + func(ctx context.Context, d *sql.DB) error { + _, err := db.ExecContext(ctx, "select 1") + return err + }, + }, + { + "db.QueryContext", + func(ctx context.Context, d *sql.DB) error { + _, err := db.QueryContext(ctx, "select 1") + return err + }, + }, + { + "tx.ExecContext", + func(ctx context.Context, d *sql.DB) error { + tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.ExecContext(ctx, "select 1"); err != nil { + return err + } + return tx.Rollback() + }, + }, + { + "tx.QueryContext", + func(ctx context.Context, d *sql.DB) error { + tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.QueryContext(ctx, "select 1"); err != nil { + return err + } + return tx.Rollback() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // this is a race condition, so it's not guaranteed to fail on any given run, + // but with a moderate number of iterations it will eventually catch it + iterations := 100 + for i := 0; i < iterations; i++ { + // none of these iterations should ever fail, because we never cancel their + // context until after they complete + ctx, cancel := context.WithCancel(context.Background()) + if err := tt.f(ctx, db); err != nil { + t.Fatalf("Failed to run test query on iteration %d: %v", i, err) + } + cancel() + } + }) + } +} diff --git a/sqlite.go b/sqlite.go index e07d859..146583e 100644 --- a/sqlite.go +++ b/sqlite.go @@ -498,20 +498,7 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res var pstmt uintptr var done int32 if ctx != nil && ctx.Done() != nil { - donech := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - atomic.AddInt32(&done, 1) - s.c.interrupt(s.c.db) - case <-donech: - } - }() - - defer func() { - close(donech) - }() + defer interruptOnDone(ctx, s.c, &done)() } for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; { @@ -599,24 +586,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro // context honoring if ctx != nil && ctx.Done() != nil { - donech := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - // set done indicator - atomic.AddInt32(&done, 1) - - // interrupt in-fly queries - s.c.interrupt(s.c.db) - case <-donech: - } - }() - - // stop context monitoring at exit - defer func() { - close(donech) - }() + defer interruptOnDone(ctx, s.c, &done)() } // generally, query may contain multiple SQL statements @@ -759,19 +729,7 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) { //TODO use t.conn.ExecContext() instead if ctx != nil && ctx.Done() != nil { - donech := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - t.c.interrupt(t.c.db) - case <-donech: - } - }() - - defer func() { - close(donech) - }() + defer interruptOnDone(ctx, t.c, nil)() } if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK { @@ -781,6 +739,43 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) { return nil } +// interruptOnDone sets up a goroutine to interrupt the provided db when the +// context is canceled, and returns a function the caller must defer so it +// doesn't interrupt after the caller finishes. +func interruptOnDone( + ctx context.Context, + c *conn, + done *int32, +) func() { + if done == nil { + var d int32 + done = &d + } + + donech := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + // don't call interrupt if we were already done: it indicates that this + // call to exec is no longer running and we would be interrupting + // nothing, or even possibly an unrelated later call to exec. + if atomic.AddInt32(done, 1) == 1 { + c.interrupt(c.db) + } + case <-donech: + } + }() + + // the caller is expected to defer this function + return func() { + // set the done flag so that a context cancellation right after the caller + // returns doesn't trigger a call to interrupt for some other statement. + atomic.AddInt32(done, 1) + close(donech) + } +} + type conn struct { db uintptr // *sqlite3.Xsqlite3 tls *libc.TLS