mirror of
https://github.com/glebarez/go-sqlite.git
synced 2025-10-04 15:32:46 +08:00
Fix race condition if exec's context is canceled just after completion
This commit is contained in:

committed by
glebarez

parent
870db7651a
commit
327c7779d4
83
all_test.go
83
all_test.go
@@ -2204,3 +2204,86 @@ func TestBeginMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://gitlab.com/cznic/sqlite/-/issues/94
|
||||||
|
func TestCancelRace(t *testing.T) {
|
||||||
|
tempDir, err := ioutil.TempDir("", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
os.RemoveAll(tempDir)
|
||||||
|
}()
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite", filepath.Join(tempDir, "testcancelrace.sqlite"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
f func(context.Context, *sql.DB) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"db.ExecContext",
|
||||||
|
func(ctx context.Context, d *sql.DB) error {
|
||||||
|
_, err := db.ExecContext(ctx, "select 1")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"db.QueryContext",
|
||||||
|
func(ctx context.Context, d *sql.DB) error {
|
||||||
|
_, err := db.QueryContext(ctx, "select 1")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tx.ExecContext",
|
||||||
|
func(ctx context.Context, d *sql.DB) error {
|
||||||
|
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.ExecContext(ctx, "select 1"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Rollback()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tx.QueryContext",
|
||||||
|
func(ctx context.Context, d *sql.DB) error {
|
||||||
|
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.QueryContext(ctx, "select 1"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Rollback()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// this is a race condition, so it's not guaranteed to fail on any given run,
|
||||||
|
// but with a moderate number of iterations it will eventually catch it
|
||||||
|
iterations := 100
|
||||||
|
for i := 0; i < iterations; i++ {
|
||||||
|
// none of these iterations should ever fail, because we never cancel their
|
||||||
|
// context until after they complete
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
if err := tt.f(ctx, db); err != nil {
|
||||||
|
t.Fatalf("Failed to run test query on iteration %d: %v", i, err)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
85
sqlite.go
85
sqlite.go
@@ -498,20 +498,7 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res
|
|||||||
var pstmt uintptr
|
var pstmt uintptr
|
||||||
var done int32
|
var done int32
|
||||||
if ctx != nil && ctx.Done() != nil {
|
if ctx != nil && ctx.Done() != nil {
|
||||||
donech := make(chan struct{})
|
defer interruptOnDone(ctx, s.c, &done)()
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
atomic.AddInt32(&done, 1)
|
|
||||||
s.c.interrupt(s.c.db)
|
|
||||||
case <-donech:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
close(donech)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; {
|
for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; {
|
||||||
@@ -599,24 +586,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
|
|||||||
|
|
||||||
// context honoring
|
// context honoring
|
||||||
if ctx != nil && ctx.Done() != nil {
|
if ctx != nil && ctx.Done() != nil {
|
||||||
donech := make(chan struct{})
|
defer interruptOnDone(ctx, s.c, &done)()
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
// set done indicator
|
|
||||||
atomic.AddInt32(&done, 1)
|
|
||||||
|
|
||||||
// interrupt in-fly queries
|
|
||||||
s.c.interrupt(s.c.db)
|
|
||||||
case <-donech:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// stop context monitoring at exit
|
|
||||||
defer func() {
|
|
||||||
close(donech)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// generally, query may contain multiple SQL statements
|
// generally, query may contain multiple SQL statements
|
||||||
@@ -759,19 +729,7 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) {
|
|||||||
//TODO use t.conn.ExecContext() instead
|
//TODO use t.conn.ExecContext() instead
|
||||||
|
|
||||||
if ctx != nil && ctx.Done() != nil {
|
if ctx != nil && ctx.Done() != nil {
|
||||||
donech := make(chan struct{})
|
defer interruptOnDone(ctx, t.c, nil)()
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
t.c.interrupt(t.c.db)
|
|
||||||
case <-donech:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
close(donech)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK {
|
if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK {
|
||||||
@@ -781,6 +739,43 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// interruptOnDone sets up a goroutine to interrupt the provided db when the
|
||||||
|
// context is canceled, and returns a function the caller must defer so it
|
||||||
|
// doesn't interrupt after the caller finishes.
|
||||||
|
func interruptOnDone(
|
||||||
|
ctx context.Context,
|
||||||
|
c *conn,
|
||||||
|
done *int32,
|
||||||
|
) func() {
|
||||||
|
if done == nil {
|
||||||
|
var d int32
|
||||||
|
done = &d
|
||||||
|
}
|
||||||
|
|
||||||
|
donech := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// don't call interrupt if we were already done: it indicates that this
|
||||||
|
// call to exec is no longer running and we would be interrupting
|
||||||
|
// nothing, or even possibly an unrelated later call to exec.
|
||||||
|
if atomic.AddInt32(done, 1) == 1 {
|
||||||
|
c.interrupt(c.db)
|
||||||
|
}
|
||||||
|
case <-donech:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// the caller is expected to defer this function
|
||||||
|
return func() {
|
||||||
|
// set the done flag so that a context cancellation right after the caller
|
||||||
|
// returns doesn't trigger a call to interrupt for some other statement.
|
||||||
|
atomic.AddInt32(done, 1)
|
||||||
|
close(donech)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
db uintptr // *sqlite3.Xsqlite3
|
db uintptr // *sqlite3.Xsqlite3
|
||||||
tls *libc.TLS
|
tls *libc.TLS
|
||||||
|
Reference in New Issue
Block a user