Fix race condition if exec's context is canceled just after completion

This commit is contained in:
Matthew Gabeler-Lee
2022-03-15 11:22:26 +00:00
committed by glebarez
parent 870db7651a
commit 327c7779d4
2 changed files with 123 additions and 45 deletions

View File

@@ -2204,3 +2204,86 @@ func TestBeginMode(t *testing.T) {
}
}
}
// https://gitlab.com/cznic/sqlite/-/issues/94
func TestCancelRace(t *testing.T) {
tempDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer func() {
os.RemoveAll(tempDir)
}()
db, err := sql.Open("sqlite", filepath.Join(tempDir, "testcancelrace.sqlite"))
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
tests := []struct {
name string
f func(context.Context, *sql.DB) error
}{
{
"db.ExecContext",
func(ctx context.Context, d *sql.DB) error {
_, err := db.ExecContext(ctx, "select 1")
return err
},
},
{
"db.QueryContext",
func(ctx context.Context, d *sql.DB) error {
_, err := db.QueryContext(ctx, "select 1")
return err
},
},
{
"tx.ExecContext",
func(ctx context.Context, d *sql.DB) error {
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, "select 1"); err != nil {
return err
}
return tx.Rollback()
},
},
{
"tx.QueryContext",
func(ctx context.Context, d *sql.DB) error {
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.QueryContext(ctx, "select 1"); err != nil {
return err
}
return tx.Rollback()
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// this is a race condition, so it's not guaranteed to fail on any given run,
// but with a moderate number of iterations it will eventually catch it
iterations := 100
for i := 0; i < iterations; i++ {
// none of these iterations should ever fail, because we never cancel their
// context until after they complete
ctx, cancel := context.WithCancel(context.Background())
if err := tt.f(ctx, db); err != nil {
t.Fatalf("Failed to run test query on iteration %d: %v", i, err)
}
cancel()
}
})
}
}

View File

@@ -498,20 +498,7 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res
var pstmt uintptr
var done int32
if ctx != nil && ctx.Done() != nil {
donech := make(chan struct{})
go func() {
select {
case <-ctx.Done():
atomic.AddInt32(&done, 1)
s.c.interrupt(s.c.db)
case <-donech:
}
}()
defer func() {
close(donech)
}()
defer interruptOnDone(ctx, s.c, &done)()
}
for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; {
@@ -599,24 +586,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
// context honoring
if ctx != nil && ctx.Done() != nil {
donech := make(chan struct{})
go func() {
select {
case <-ctx.Done():
// set done indicator
atomic.AddInt32(&done, 1)
// interrupt in-fly queries
s.c.interrupt(s.c.db)
case <-donech:
}
}()
// stop context monitoring at exit
defer func() {
close(donech)
}()
defer interruptOnDone(ctx, s.c, &done)()
}
// generally, query may contain multiple SQL statements
@@ -759,19 +729,7 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) {
//TODO use t.conn.ExecContext() instead
if ctx != nil && ctx.Done() != nil {
donech := make(chan struct{})
go func() {
select {
case <-ctx.Done():
t.c.interrupt(t.c.db)
case <-donech:
}
}()
defer func() {
close(donech)
}()
defer interruptOnDone(ctx, t.c, nil)()
}
if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK {
@@ -781,6 +739,43 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) {
return nil
}
// interruptOnDone sets up a goroutine to interrupt the provided db when the
// context is canceled, and returns a function the caller must defer so it
// doesn't interrupt after the caller finishes.
func interruptOnDone(
ctx context.Context,
c *conn,
done *int32,
) func() {
if done == nil {
var d int32
done = &d
}
donech := make(chan struct{})
go func() {
select {
case <-ctx.Done():
// don't call interrupt if we were already done: it indicates that this
// call to exec is no longer running and we would be interrupting
// nothing, or even possibly an unrelated later call to exec.
if atomic.AddInt32(done, 1) == 1 {
c.interrupt(c.db)
}
case <-donech:
}
}()
// the caller is expected to defer this function
return func() {
// set the done flag so that a context cancellation right after the caller
// returns doesn't trigger a call to interrupt for some other statement.
atomic.AddInt32(done, 1)
close(donech)
}
}
type conn struct {
db uintptr // *sqlite3.Xsqlite3
tls *libc.TLS