db: add Connection

This commit is contained in:
Asdine El Hrychy
2024-02-18 12:55:31 +04:00
parent 5095097a0a
commit 6bc4992d70
32 changed files with 672 additions and 448 deletions

View File

@@ -31,12 +31,17 @@ type execer func(q string, args ...interface{}) error
// If tables is provided, only selected tables will be outputted.
func Bench(db *chai.DB, query string, opt BenchOptions) error {
var tx *chai.Tx
var p preparer = db
conn, err := db.Connect()
if err != nil {
return err
}
defer conn.Close()
var p preparer = conn
var e execer = db.Exec
var err error
if opt.SameTx {
tx, err = db.Begin(true)
tx, err = conn.Begin(true)
if err != nil {
return err
}

View File

@@ -12,7 +12,13 @@ import (
// Dump takes a database and dumps its content as SQL queries in the given writer.
// If tables is provided, only selected tables will be outputted.
func Dump(db *chai.DB, w io.Writer, tables ...string) error {
tx, err := db.Begin(false)
conn, err := db.Connect()
if err != nil {
return err
}
defer conn.Close()
tx, err := conn.Begin(false)
if err != nil {
return err
}
@@ -97,7 +103,13 @@ func dumpTable(tx *chai.Tx, w io.Writer, query, tableName string) error {
// DumpSchema takes a database and dumps its schema as SQL queries in the given writer.
// If tables are provided, only selected tables will be outputted.
func DumpSchema(db *chai.DB, w io.Writer, tables ...string) error {
tx, err := db.Begin(false)
conn, err := db.Connect()
if err != nil {
return err
}
defer conn.Close()
tx, err := conn.Begin(false)
if err != nil {
return err
}

View File

@@ -19,11 +19,18 @@ func ExecSQL(ctx context.Context, db *chai.DB, r io.Reader, w io.Writer) error {
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
conn, err := db.Connect()
if err != nil {
return err
}
defer conn.Close()
return parser.NewParser(r).Parse(func(s statement.Statement) error {
qq := query.New(s)
qctx := query.Context{
Ctx: ctx,
DB: db.DB,
Ctx: ctx,
DB: db.DB,
Conn: conn.Conn,
}
err := qq.Prepare(&qctx)
if err != nil {

View File

@@ -50,7 +50,13 @@ func ListIndexes(db *chai.DB, tableName string) ([]string, error) {
if tableName != "" {
q += " AND owner_table_name = ?"
}
res, err := db.Query(q, tableName)
conn, err := db.Connect()
if err != nil {
return nil, err
}
defer conn.Close()
res, err := conn.Query(q, tableName)
if err != nil {
return nil, err
}

View File

@@ -115,7 +115,13 @@ func runHelpCmd(out io.Writer) error {
// runTablesCmd displays all tables.
func runTablesCmd(db *chai.DB, w io.Writer) error {
res, err := db.Query("SELECT name FROM __chai_catalog WHERE type = 'table' AND name NOT LIKE '__chai_%'")
conn, err := db.Connect()
if err != nil {
return err
}
defer conn.Close()
res, err := conn.Query("SELECT name FROM __chai_catalog WHERE type = 'table' AND name NOT LIKE '__chai_%'")
if err != nil {
return err
}
@@ -195,7 +201,13 @@ func runImportCmd(db *chai.DB, fileType, path, table string) error {
}
defer f.Close()
tx, err := db.Begin(true)
conn, err := db.Connect()
if err != nil {
return err
}
defer conn.Close()
tx, err := conn.Begin(true)
if err != nil {
return err
}

152
db.go
View File

@@ -47,21 +47,64 @@ func Open(path string) (*DB, error) {
}, nil
}
func (db *DB) Connect() (*Connection, error) {
conn, err := db.DB.Connect()
if err != nil {
return nil, err
}
return &Connection{
db: db,
Conn: conn,
}, nil
}
// WithContext creates a new database handle using the given context for every operation.
func (db DB) WithContext(ctx context.Context) *DB {
db.ctx = ctx
return &db
}
func (db *DB) withConn(fn func(*Connection) error) error {
conn, err := db.Connect()
if err != nil {
return err
}
defer conn.Close()
return fn(conn)
}
// QueryRow runs the query and returns the first row.
func (db *DB) QueryRow(q string, args ...any) (r *Row, err error) {
err = db.withConn(func(c *Connection) error {
r, err = c.QueryRow(q, args...)
return err
})
return
}
// Exec a query against the database without returning the result.
func (db *DB) Exec(q string, args ...any) error {
return db.withConn(func(c *Connection) error {
return c.Exec(q, args...)
})
}
// Close the database.
func (db *DB) Close() error {
return db.DB.Close()
}
type Connection struct {
db *DB
Conn *database.Connection
}
// Begin starts a new transaction.
// The returned transaction must be closed either by calling Rollback or Commit.
func (db *DB) Begin(writable bool) (*Tx, error) {
tx, err := db.DB.BeginTx(&database.TxOptions{
func (c *Connection) Begin(writable bool) (*Tx, error) {
_, err := c.Conn.BeginTx(&database.TxOptions{
ReadOnly: !writable,
})
if err != nil {
@@ -69,14 +112,13 @@ func (db *DB) Begin(writable bool) (*Tx, error) {
}
return &Tx{
db: db,
tx: tx,
conn: c,
}, nil
}
// View starts a read only transaction, runs fn and automatically rolls it back.
func (db *DB) View(fn func(tx *Tx) error) error {
tx, err := db.Begin(false)
func (c *Connection) View(fn func(tx *Tx) error) error {
tx, err := c.Begin(false)
if err != nil {
return err
}
@@ -86,8 +128,8 @@ func (db *DB) View(fn func(tx *Tx) error) error {
}
// Update starts a read-write transaction, runs fn and automatically commits it.
func (db *DB) Update(fn func(tx *Tx) error) error {
tx, err := db.Begin(true)
func (c *Connection) Update(fn func(tx *Tx) error) error {
tx, err := c.Begin(true)
if err != nil {
return err
}
@@ -103,18 +145,25 @@ func (db *DB) Update(fn func(tx *Tx) error) error {
// Query the database and return the result.
// The returned result must always be closed after usage.
func (db *DB) Query(q string, args ...any) (*Result, error) {
stmt, err := db.Prepare(q)
func (c *Connection) Query(q string, args ...any) (*Result, error) {
stmt, err := c.Prepare(q)
if err != nil {
return nil, err
}
return stmt.Query(args...)
res, err := stmt.Query(args...)
if err != nil {
return nil, err
}
res.conn = c
return res, nil
}
// QueryRow runs the query and returns the first row.
func (db *DB) QueryRow(q string, args ...any) (*Row, error) {
stmt, err := db.Prepare(q)
func (c *Connection) QueryRow(q string, args ...any) (*Row, error) {
stmt, err := c.Prepare(q)
if err != nil {
return nil, err
}
@@ -123,8 +172,8 @@ func (db *DB) QueryRow(q string, args ...any) (*Row, error) {
}
// Exec a query against the database without returning the result.
func (db *DB) Exec(q string, args ...any) error {
stmt, err := db.Prepare(q)
func (c *Connection) Exec(q string, args ...any) error {
stmt, err := c.Prepare(q)
if err != nil {
return err
}
@@ -133,41 +182,54 @@ func (db *DB) Exec(q string, args ...any) error {
}
// Prepare parses the query and returns a prepared statement.
func (db *DB) Prepare(q string) (*Statement, error) {
func (c *Connection) Prepare(q string) (*Statement, error) {
pq, err := parser.ParseQuery(q)
if err != nil {
return nil, err
}
err = pq.Prepare(newQueryContext(db, nil, nil))
err = pq.Prepare(newQueryContext(c, nil))
if err != nil {
return nil, err
}
return &Statement{
pq: pq,
db: db,
pq: pq,
conn: c,
}, nil
}
func (c *Connection) Close() error {
return c.Conn.Close()
}
// Tx represents a database transaction. It provides methods for managing the
// collection of tables and the transaction itself.
// Tx is either read-only or read/write. Read-only can be used to read tables
// and read/write can be used to read, create, delete and modify tables.
type Tx struct {
db *DB
tx *database.Transaction
conn *Connection
}
// Rollback the transaction. Can be used safely after commit.
func (tx *Tx) Rollback() error {
return tx.tx.Rollback()
t := tx.conn.Conn.GetTx()
if t == nil {
return errors.New("transaction has already been committed or rolled back")
}
return t.Rollback()
}
// Commit the transaction. Calling this method on read-only transactions
// will return an error.
func (tx *Tx) Commit() error {
return tx.tx.Commit()
t := tx.conn.Conn.GetTx()
if t == nil {
return errors.New("transaction has already been committed or rolled back")
}
return t.Commit()
}
// Query the database withing the transaction and returns the result.
@@ -208,15 +270,15 @@ func (tx *Tx) Prepare(q string) (*Statement, error) {
return nil, err
}
err = pq.Prepare(newQueryContext(tx.db, tx, nil))
err = pq.Prepare(newQueryContext(tx.conn, nil))
if err != nil {
return nil, err
}
return &Statement{
pq: pq,
db: tx.db,
tx: tx,
pq: pq,
conn: tx.conn,
tx: tx,
}, nil
}
@@ -225,9 +287,9 @@ func (tx *Tx) Prepare(q string) (*Statement, error) {
// is valid until the DB closes.
// It's safe for concurrent use by multiple goroutines.
type Statement struct {
pq query.Query
db *DB
tx *Tx
pq query.Query
conn *Connection
tx *Tx
}
// Query the database and return the result.
@@ -236,12 +298,12 @@ func (s *Statement) Query(args ...any) (*Result, error) {
var r *statement.Result
var err error
r, err = s.pq.Run(newQueryContext(s.db, s.tx, argsToParams(args)))
r, err = s.pq.Run(newQueryContext(s.conn, argsToParams(args)))
if err != nil {
return nil, err
}
return &Result{result: r, ctx: s.db.ctx}, nil
return &Result{result: r, ctx: s.conn.db.ctx}, nil
}
func argsToParams(args []interface{}) []environment.Param {
@@ -310,6 +372,7 @@ func (s *Statement) Exec(args ...any) (err error) {
type Result struct {
result *statement.Result
ctx context.Context
conn *Connection
}
func (r *Result) Iterate(fn func(r *Row) error) error {
@@ -366,12 +429,12 @@ func (r *Result) Columns() []string {
break
}
fields := make([]string, len(po.Exprs))
columns := make([]string, len(po.Exprs))
for i := range po.Exprs {
fields[i] = po.Exprs[i].String()
columns[i] = po.Exprs[i].String()
}
return fields
return columns
}
}
@@ -385,7 +448,9 @@ func (r *Result) Close() (err error) {
return nil
}
return r.result.Close()
err = r.result.Close()
return err
}
func (r *Result) MarshalJSON() ([]byte, error) {
@@ -427,18 +492,13 @@ func (r *Result) MarshalJSONTo(w io.Writer) error {
return buf.Flush()
}
func newQueryContext(db *DB, tx *Tx, params []environment.Param) *query.Context {
ctx := query.Context{
Ctx: db.ctx,
DB: db.DB,
func newQueryContext(conn *Connection, params []environment.Param) *query.Context {
return &query.Context{
Ctx: conn.db.ctx,
DB: conn.db.DB,
Conn: conn.Conn,
Params: params,
}
if tx != nil {
ctx.Tx = tx.tx
}
return &ctx
}
type Row struct {

View File

@@ -3,7 +3,6 @@ package chai_test
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"testing"
@@ -17,11 +16,17 @@ import (
func ExampleTx() {
db, err := chai.Open(":memory:")
if err != nil {
log.Fatal(err)
panic(err)
}
defer db.Close()
tx, err := db.Begin(true)
conn, err := db.Connect()
if err != nil {
panic(err)
}
defer conn.Close()
tx, err := conn.Begin(true)
if err != nil {
panic(err)
}
@@ -98,7 +103,11 @@ func TestOpen(t *testing.T) {
require.NoError(t, err)
defer db.Close()
res1, err := db.Query("SELECT * FROM __chai_catalog")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
res1, err := conn.Query("SELECT * FROM __chai_catalog")
require.NoError(t, err)
defer res1.Close()
@@ -144,7 +153,11 @@ func TestQueryRow(t *testing.T) {
require.NoError(t, err)
require.NoError(t, err)
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(true)
require.NoError(t, err)
err = tx.Exec(`
@@ -165,7 +178,11 @@ func TestQueryRow(t *testing.T) {
require.Equal(t, 1, a)
require.Equal(t, "foo", b)
tx, err := db.Begin(false)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(false)
require.NoError(t, err)
defer tx.Rollback()
@@ -182,7 +199,11 @@ func TestQueryRow(t *testing.T) {
require.True(t, chai.IsNotFoundError(err))
require.Nil(t, r)
tx, err := db.Begin(false)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(false)
require.NoError(t, err)
defer tx.Rollback()
r, err = tx.QueryRow("SELECT * FROM test WHERE a > 100")
@@ -196,10 +217,14 @@ func TestPrepareThreadSafe(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec("CREATE TABLE test(a int unique, b text); INSERT INTO test(a, b) VALUES (1, 'a'), (2, 'a')")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec("CREATE TABLE test(a int unique, b text); INSERT INTO test(a, b) VALUES (1, 'a'), (2, 'a')")
require.NoError(t, err)
stmt, err := db.Prepare("SELECT COUNT(a) FROM test WHERE a < ? GROUP BY b ORDER BY a DESC LIMIT 5")
stmt, err := conn.Prepare("SELECT COUNT(a) FROM test WHERE a < ? GROUP BY b ORDER BY a DESC LIMIT 5")
require.NoError(t, err)
g, _ := errgroup.WithContext(context.Background())
@@ -240,7 +265,11 @@ func TestIterateDeepCopy(t *testing.T) {
`)
require.NoError(t, err)
res, err := db.Query(`SELECT * FROM foo ORDER BY a DESC`)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
res, err := conn.Query(`SELECT * FROM foo ORDER BY a DESC`)
require.NoError(t, err)
type item struct {
@@ -263,115 +292,3 @@ func TestIterateDeepCopy(t *testing.T) {
require.Equal(t, &item{A: 2, B: "sample text 2"}, items[0])
require.Equal(t, &item{A: 1, B: "sample text 1"}, items[1])
}
func BenchmarkSelect(b *testing.B) {
for size := 1; size <= 10000; size *= 10 {
b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) {
db, err := chai.Open(":memory:")
require.NoError(b, err)
err = db.Exec("CREATE TABLE foo")
require.NoError(b, err)
for i := 0; i < size; i++ {
err = db.Exec("INSERT INTO foo(a, b) VALUES (1, 2);")
require.NoError(b, err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
res, _ := db.Query("SELECT * FROM foo")
res.Iterate(func(d *chai.Row) error { return nil })
}
})
}
}
func BenchmarkSelectWhere(b *testing.B) {
for size := 1; size <= 10000; size *= 10 {
b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) {
db, err := chai.Open(":memory:")
require.NoError(b, err)
err = db.Exec("CREATE TABLE foo")
require.NoError(b, err)
for i := 0; i < size; i++ {
err = db.Exec("INSERT INTO foo(a, b) VALUES (1, 2);")
require.NoError(b, err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
res, _ := db.Query("SELECT b FROM foo WHERE a > 0")
res.Iterate(func(d *chai.Row) error { return nil })
}
})
}
}
func BenchmarkPreparedSelectWhere(b *testing.B) {
for size := 1; size <= 10000; size *= 10 {
b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) {
db, err := chai.Open(":memory:")
require.NoError(b, err)
err = db.Exec("CREATE TABLE foo")
require.NoError(b, err)
for i := 0; i < size; i++ {
err = db.Exec("INSERT INTO foo(a, b) VALUES (1, 2);")
require.NoError(b, err)
}
p, _ := db.Prepare("SELECT b FROM foo WHERE a > 0")
b.ResetTimer()
for i := 0; i < b.N; i++ {
res, _ := p.Query()
res.Iterate(func(d *chai.Row) error { return nil })
}
})
}
}
func BenchmarkSelectPk(b *testing.B) {
for size := 1; size <= 10000; size *= 10 {
b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) {
db, err := chai.Open(":memory:")
require.NoError(b, err)
err = db.Exec("CREATE TABLE foo(a INT PRIMARY KEY)")
require.NoError(b, err)
for i := 0; i < size; i++ {
err = db.Exec("INSERT INTO foo(a) VALUES (?)", i)
require.NoError(b, err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
res, _ := db.Query("SELECT * FROM foo WHERE a = ?", size-1)
res.Iterate(func(d *chai.Row) error { return nil })
}
})
}
}
func BenchmarkInsert(b *testing.B) {
db, err := chai.Open(b.TempDir())
require.NoError(b, err)
defer db.Close()
err = db.Exec("CREATE TABLE foo(a INT)")
require.NoError(b, err)
stmt, err := db.Prepare("INSERT INTO foo(a) VALUES (?)")
require.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
stmt.Exec(j)
}
}
}

View File

@@ -5,7 +5,6 @@ import (
"database/sql"
"database/sql/driver"
"io"
"runtime"
"sync"
"time"
@@ -39,13 +38,10 @@ func (d sqlDriver) OpenConnector(name string) (driver.Connector, error) {
return nil, err
}
c := &connector{
return &connector{
db: db,
driver: d,
}
runtime.SetFinalizer(c, (*connector).Close)
return c, nil
}, nil
}
var (
@@ -54,15 +50,21 @@ var (
)
type connector struct {
driver driver.Driver
db *chai.DB
driver driver.Driver
db *chai.DB
closeOnce sync.Once
}
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return &conn{db: c.db}, nil
cc, err := c.db.Connect()
if err != nil {
return nil, err
}
return &conn{
db: c.db,
conn: cc,
}, nil
}
func (c *connector) Driver() driver.Driver {
@@ -80,8 +82,8 @@ func (c *connector) Close() error {
// conn represents a connection to the Chai database.
// It implements the database/sql/driver.Conn interface.
type conn struct {
db *chai.DB
tx *chai.Tx
db *chai.DB
conn *chai.Connection
}
// Prepare returns a prepared statement, bound to this connection.
@@ -91,14 +93,7 @@ func (c *conn) Prepare(q string) (driver.Stmt, error) {
// PrepareContext returns a prepared statement, bound to this connection.
func (c *conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
var s *chai.Statement
var err error
if c.tx != nil {
s, err = c.tx.Prepare(q)
} else {
s, err = c.db.Prepare(q)
}
s, err := c.conn.Prepare(q)
if err != nil {
return nil, err
}
@@ -110,11 +105,7 @@ func (c *conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error
// Close closes any ongoing transaction.
func (c *conn) Close() error {
if c.tx != nil {
return c.tx.Rollback()
}
return nil
return c.conn.Close()
}
// Begin starts and returns a new transaction.
@@ -122,6 +113,15 @@ func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c *conn) ResetSession(ctx context.Context) error {
err := c.conn.Conn.Reset()
if err != nil {
return driver.ErrBadConn
}
return nil
}
// BeginTx starts and returns a new transaction.
// It uses the ReadOnly option to determine whether to start a read-only or read/write transaction.
// If the Isolation option is non zero, an error is returned.
@@ -130,26 +130,9 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
return nil, errors.New("isolation levels are not supported")
}
db := c.db.WithContext(ctx)
// if the ReadOnly flag is explicitly specified, create a read-only transaction,
// otherwise create a read/write transaction.
var err error
c.tx, err = db.Begin(!opts.ReadOnly)
return c, err
}
func (c *conn) Commit() error {
err := c.tx.Commit()
c.tx = nil
return err
}
func (c *conn) Rollback() error {
err := c.tx.Rollback()
c.tx = nil
return err
return c.conn.Begin(!opts.ReadOnly)
}
// Stmt is a prepared statement. It is bound to a Conn and not
@@ -176,18 +159,18 @@ func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver
default:
}
return result{}, s.stmt.Exec(driverNamedValueToParams(args)...)
return execResult{}, s.stmt.Exec(namedValueToParams(args)...)
}
type result struct{}
type execResult struct{}
// LastInsertId is not supported and returns an error.
func (r result) LastInsertId() (int64, error) {
func (r execResult) LastInsertId() (int64, error) {
return 0, errors.New("not supported")
}
// RowsAffected is not supported and returns an error.
func (r result) RowsAffected() (int64, error) {
func (r execResult) RowsAffected() (int64, error) {
return 0, errors.New("not supported")
}
@@ -204,17 +187,17 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
default:
}
res, err := s.stmt.Query(driverNamedValueToParams(args)...)
res, err := s.stmt.Query(namedValueToParams(args)...)
if err != nil {
return nil, err
}
rs := newRecordStream(res)
rs := newRows(res)
rs.columns = res.Columns()
return rs, nil
}
func driverNamedValueToParams(args []driver.NamedValue) []any {
func namedValueToParams(args []driver.NamedValue) []any {
params := make([]any, len(args))
for i, arg := range args {
var p environment.Param
@@ -233,35 +216,35 @@ func (s stmt) Close() error {
var errStop = errors.New("stop")
type recordStream struct {
type Rows struct {
res *chai.Result
cancelFn func()
c chan recordRow
c chan Row
wg sync.WaitGroup
columns []string
}
type recordRow struct {
type Row struct {
r *chai.Row
err error
}
func newRecordStream(res *chai.Result) *recordStream {
func newRows(res *chai.Result) *Rows {
ctx, cancel := context.WithCancel(context.Background())
ds := recordStream{
rs := Rows{
res: res,
cancelFn: cancel,
c: make(chan recordRow),
c: make(chan Row),
}
ds.wg.Add(1)
rs.wg.Add(1)
go ds.iterate(ctx)
go rs.iterate(ctx)
return &ds
return &rs
}
func (rs *recordStream) iterate(ctx context.Context) {
func (rs *Rows) iterate(ctx context.Context) {
defer rs.wg.Done()
defer close(rs.c)
@@ -275,7 +258,7 @@ func (rs *recordStream) iterate(ctx context.Context) {
select {
case <-ctx.Done():
return errStop
case rs.c <- recordRow{
case rs.c <- Row{
r: r,
}:
@@ -292,7 +275,7 @@ func (rs *recordStream) iterate(ctx context.Context) {
return
}
if err != nil {
rs.c <- recordRow{
rs.c <- Row{
err: err,
}
return
@@ -300,19 +283,19 @@ func (rs *recordStream) iterate(ctx context.Context) {
}
// Columns returns the fields selected by the SELECT statement.
func (rs *recordStream) Columns() []string {
func (rs *Rows) Columns() []string {
return rs.res.Columns()
}
// Close closes the rows iterator.
func (rs *recordStream) Close() error {
func (rs *Rows) Close() error {
rs.cancelFn()
rs.wg.Wait()
return rs.res.Close()
}
func (rs *recordStream) Next(dest []driver.Value) error {
rs.c <- recordRow{}
func (rs *Rows) Next(dest []driver.Value) error {
rs.c <- Row{}
row, ok := <-rs.c
if !ok {

View File

@@ -38,8 +38,14 @@ func Example() {
panic(err)
}
conn, err := db.Connect()
if err != nil {
panic(err)
}
defer conn.Close()
// Query some rows
stream, err := db.Query("SELECT * FROM user WHERE id > ?", 1)
stream, err := conn.Query("SELECT * FROM user WHERE id > ?", 1)
if err != nil {
panic(err)
}

View File

@@ -404,6 +404,7 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
@@ -489,6 +490,7 @@ golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/oauth2 v0.0.0-20170207211851-4464e7848382/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -500,6 +502,7 @@ golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4Iltr
golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4=
golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o=
golang.org/x/perf v0.0.0-20230113213139-801c7ef9e5c5 h1:ObuXPmIgI4ZMyQLIz48cJYgSyWdjUXc2SZAdyJMwEAU=
golang.org/x/perf v0.0.0-20230113213139-801c7ef9e5c5/go.mod h1:UBKtEnL8aqnd+0JHqZ+2qoMDwtuy6cYhhKNoHLBiTQc=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -441,7 +441,11 @@ func TestReadOnlyTables(t *testing.T) {
require.NoError(t, err)
defer db.Close()
res, err := db.Query(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
res, err := conn.Query(`
CREATE TABLE foo (a int, b double unique, c text);
CREATE INDEX idx_foo_a ON foo(a, c);
SELECT * FROM __chai_catalog
@@ -565,20 +569,28 @@ func TestCatalogConcurrency(t *testing.T) {
require.NoError(t, err)
defer db.Close()
conn1, err := db.Connect()
require.NoError(t, err)
defer conn1.Close()
// create a table
err = db.Exec(`
err = conn1.Exec(`
CREATE TABLE test (a int);
CREATE INDEX idx_test_a ON test(a);
`)
require.NoError(t, err)
// start a transaction rt1
rt1, err := db.Begin(false)
rt1, err := conn1.Begin(false)
require.NoError(t, err)
defer rt1.Rollback()
conn2, err := db.Connect()
require.NoError(t, err)
defer conn2.Close()
// start a transaction wt2
wt1, err := db.Begin(true)
wt1, err := conn2.Begin(true)
require.NoError(t, err)
defer wt1.Rollback()

View File

@@ -0,0 +1,68 @@
package database
import (
"context"
"github.com/cockroachdb/errors"
)
type Connection struct {
db *Database
ctx context.Context
tx *Transaction
}
// BeginTx starts a new transaction with the given options.
// If opts is empty, it will use the default options.
// The returned transaction must be closed either by calling Rollback or Commit.
func (c *Connection) BeginTx(opts *TxOptions) (*Transaction, error) {
if c.ctx.Err() != nil {
return nil, errors.New("connection is closed")
}
if c.tx != nil {
return nil, errors.New("cannot open a transaction within a transaction")
}
tx, err := c.db.beginTx(opts)
if err != nil {
return nil, err
}
c.tx = tx
tx.conn = c
tx.OnRollbackHooks = append(tx.OnRollbackHooks, c.releaseAttachedTx)
tx.OnCommitHooks = append(tx.OnCommitHooks, c.releaseAttachedTx)
return tx, nil
}
func (c *Connection) Reset() error {
if c.tx != nil {
return errors.New("cannot reset a connection with an attached transaction")
}
return nil
}
func (c *Connection) releaseAttachedTx() {
if c.tx != nil {
c.tx = nil
}
}
// GetAttachedTx returns the transaction attached to the connection, if any.
// The returned transaction is not thread safe.
func (c *Connection) GetTx() *Transaction {
return c.tx
}
func (c *Connection) Close() error {
defer c.db.connectionWg.Done()
if c.tx != nil {
return c.tx.Rollback()
}
return nil
}

View File

@@ -2,6 +2,7 @@
package database
import (
"context"
"sync"
"sync/atomic"
"time"
@@ -19,12 +20,12 @@ type Database struct {
catalogMu sync.RWMutex
catalog *Catalog
// If this is non-nil, the user is running an explicit transaction
// using the BEGIN statement.
// Only one attached transaction can be run at a time and any calls to DB.Begin()
// will cause an error until that transaction is rolled back or commited.
attachedTransaction *Transaction
attachedTxMu sync.Mutex
// context used to notify all connections that the database is closing.
closeContext context.Context
closeCancel context.CancelFunc
// waitgroup to wait for all connections to be closed.
connectionWg sync.WaitGroup
// This is used to prevent creating a new transaction
// during certain operations (commit, close, etc.)
@@ -33,11 +34,11 @@ type Database struct {
// This limits the number of write transactions to 1.
writetxmu sync.Mutex
// TransactionIDs is used to assign transaction an ID at runtime.
// transactionIDs is used to assign transaction an ID at runtime.
// Since transaction IDs are not persisted and not used for concurrent
// access, we can use 8 bytes ids that will be reset every time
// the database restarts.
TransactionIDs uint64
transactionIDs atomic.Uint64
closeOnce sync.Once
@@ -62,10 +63,6 @@ type CatalogLoader interface {
type TxOptions struct {
// Open a read-only transaction.
ReadOnly bool
// Set the transaction as global at the database level.
// Any queries run by the database will use that transaction until it is
// rolled back or commited.
Attached bool
}
func Open(path string, opts *Options) (*Database, error) {
@@ -82,6 +79,9 @@ func Open(path string, opts *Options) (*Database, error) {
Engine: store,
}
// create a context that will be cancelled when the database is closed.
db.closeContext, db.closeCancel = context.WithCancel(context.Background())
// ensure the rollback segment doesn't contain any data that needs to be rolled back
// due to a previous crash.
err = db.Engine.Recover()
@@ -129,6 +129,9 @@ func (db *Database) Close() error {
var err error
db.closeOnce.Do(func() {
db.closeCancel()
db.connectionWg.Wait()
err = db.closeDatabase()
})
@@ -136,16 +139,8 @@ func (db *Database) Close() error {
}
func (db *Database) closeDatabase() error {
// If there is an attached transaction
// it must be rolled back before closing the engine.
if tx := db.GetAttachedTx(); tx != nil {
_ = tx.Rollback()
}
db.writetxmu.Lock()
defer db.writetxmu.Unlock()
// release all sequences
tx, err := db.beginTx(nil)
tx, err := db.beginTxUnlocked(nil)
if err != nil {
return err
}
@@ -171,20 +166,29 @@ func (db *Database) closeDatabase() error {
return db.Engine.Close()
}
// GetAttachedTx returns the transaction attached to the database. It returns nil if there is no
// such transaction.
// The returned transaction is not thread safe.
func (db *Database) GetAttachedTx() *Transaction {
db.attachedTxMu.Lock()
defer db.attachedTxMu.Unlock()
// Connect returns a new connection to the database.
// The returned connection is not thread safe.
// It is the caller's responsibility to close the connection.
func (db *Database) Connect() (*Connection, error) {
if db.closeContext.Err() != nil {
return nil, errors.New("database is closed")
}
return db.attachedTransaction
db.connectionWg.Add(1)
return &Connection{
db: db,
ctx: db.closeContext,
}, nil
}
// Begin starts a new transaction with default options.
// The returned transaction must be closed either by calling Rollback or Commit.
func (db *Database) Begin(writable bool) (*Transaction, error) {
return db.BeginTx(&TxOptions{
if db.closeContext.Err() != nil {
return nil, errors.New("database is closed")
}
return db.beginTx(&TxOptions{
ReadOnly: !writable,
})
}
@@ -192,10 +196,11 @@ func (db *Database) Begin(writable bool) (*Transaction, error) {
// BeginTx starts a new transaction with the given options.
// If opts is empty, it will use the default options.
// The returned transaction must be closed either by calling Rollback or Commit.
// If the Attached option is passed, it opens a database level transaction, which gets
// attached to the database and prevents any other transaction to be opened afterwards
// until it gets rolled back or commited.
func (db *Database) BeginTx(opts *TxOptions) (*Transaction, error) {
func (db *Database) beginTx(opts *TxOptions) (*Transaction, error) {
if db.closeContext.Err() != nil {
return nil, errors.New("database is closed")
}
db.txmu.RLock()
defer db.txmu.RUnlock()
@@ -207,18 +212,11 @@ func (db *Database) BeginTx(opts *TxOptions) (*Transaction, error) {
db.writetxmu.Lock()
}
db.attachedTxMu.Lock()
defer db.attachedTxMu.Unlock()
if db.attachedTransaction != nil {
return nil, errors.New("cannot open a transaction within a transaction")
}
return db.beginTx(opts)
return db.beginTxUnlocked(opts)
}
// beginTx creates a transaction without locks.
func (db *Database) beginTx(opts *TxOptions) (*Transaction, error) {
// beginTxUnlocked creates a transaction without locks.
func (db *Database) beginTxUnlocked(opts *TxOptions) (*Transaction, error) {
if opts == nil {
opts = &TxOptions{}
}
@@ -235,7 +233,7 @@ func (db *Database) beginTx(opts *TxOptions) (*Transaction, error) {
Engine: db.Engine,
Session: sess,
Writable: !opts.ReadOnly,
ID: atomic.AddUint64(&db.TransactionIDs, 1),
ID: db.transactionIDs.Add(1),
Catalog: db.Catalog(),
TxStart: time.Now(),
}
@@ -244,12 +242,6 @@ func (db *Database) beginTx(opts *TxOptions) (*Transaction, error) {
tx.WriteTxMu = &db.writetxmu
}
if opts.Attached {
db.attachedTransaction = &tx
tx.OnRollbackHooks = append(tx.OnRollbackHooks, db.releaseAttachedTx)
tx.OnCommitHooks = append(tx.OnCommitHooks, db.releaseAttachedTx)
}
return &tx, nil
}
@@ -265,12 +257,3 @@ func (db *Database) SetCatalog(c *Catalog) {
db.catalog = c
db.catalogMu.Unlock()
}
func (db *Database) releaseAttachedTx() {
db.attachedTxMu.Lock()
defer db.attachedTxMu.Unlock()
if db.attachedTransaction != nil {
db.attachedTransaction = nil
}
}

View File

@@ -21,7 +21,11 @@ func TestConcurrentTransactionManagement(t *testing.T) {
go func() {
// 1. Start transaction T1.
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(true)
require.NoError(t, err)
// Start transaction T2.
@@ -42,7 +46,11 @@ func TestConcurrentTransactionManagement(t *testing.T) {
// 2. Attempt to start transaction T2.
// Waits for T1 to finish.
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(true)
require.NoError(t, err)
require.NoError(t, tx.Rollback())

View File

@@ -13,7 +13,8 @@ import (
// Transaction is either read-only or read/write. Read-only can be used to read tables
// and read/write can be used to read, create, delete and modify tables.
type Transaction struct {
db *Database
db *Database
conn *Connection
// Timestamp at which the transaction was created.
// The timestamp must use the local timezone.
@@ -33,6 +34,10 @@ type Transaction struct {
catalogWriter *CatalogWriter
}
func (tx *Transaction) Connection() *Connection {
return tx.conn
}
// Rollback the transaction. Can be used safely after commit.
func (tx *Transaction) Rollback() error {
err := tx.Session.Close()

View File

@@ -249,6 +249,10 @@ func TestQueries(t *testing.T) {
db, err := chai.Open(filepath.Join(dir, "pebble"))
require.NoError(t, err)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
r, err := db.QueryRow(`
CREATE TABLE test(a INT);
INSERT INTO test (a) VALUES (1), (2), (3), (4);
@@ -261,7 +265,7 @@ func TestQueries(t *testing.T) {
require.Equal(t, 4, count)
t.Run("ORDER BY", func(t *testing.T) {
st, err := db.Query("SELECT * FROM test ORDER BY a DESC")
st, err := conn.Query("SELECT * FROM test ORDER BY a DESC")
require.NoError(t, err)
defer st.Close()
@@ -297,7 +301,11 @@ func TestQueries(t *testing.T) {
db, err := chai.Open(filepath.Join(dir, "pebble"))
require.NoError(t, err)
st, err := db.Query(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
st, err := conn.Query(`
CREATE TABLE test(a INT);
INSERT INTO test (a) VALUES (1), (2), (3), (4);
UPDATE test SET a = 5;
@@ -317,16 +325,23 @@ func TestQueries(t *testing.T) {
db, err := chai.Open(filepath.Join(dir, "pebble"))
require.NoError(t, err)
err = db.Exec("CREATE TABLE test(a INT)")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec("CREATE TABLE test(a INT)")
require.NoError(t, err)
err = db.Update(func(tx *chai.Tx) error {
for i := 1; i < 200; i++ {
err = tx.Exec("INSERT INTO test (a) VALUES (?)", i)
require.NoError(t, err)
}
return nil
})
tx, err := conn.Begin(true)
require.NoError(t, err)
defer tx.Rollback()
for i := 1; i < 200; i++ {
err = tx.Exec("INSERT INTO test (a) VALUES (?)", i)
require.NoError(t, err)
}
err = tx.Commit()
require.NoError(t, err)
r, err := db.QueryRow(`
@@ -349,19 +364,26 @@ func TestQueriesSameTransaction(t *testing.T) {
db, err := chai.Open(filepath.Join(dir, "pebble"))
require.NoError(t, err)
err = db.Update(func(tx *chai.Tx) error {
r, err := tx.QueryRow(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(true)
require.NoError(t, err)
defer tx.Rollback()
r, err := tx.QueryRow(`
CREATE TABLE test(a INT);
INSERT INTO test (a) VALUES (1), (2), (3), (4);
SELECT COUNT(*) FROM test;
`)
require.NoError(t, err)
var count int
err = r.Scan(&count)
require.NoError(t, err)
require.Equal(t, 4, count)
return nil
})
require.NoError(t, err)
var count int
err = r.Scan(&count)
require.NoError(t, err)
require.Equal(t, 4, count)
err = tx.Commit()
require.NoError(t, err)
})
@@ -371,14 +393,21 @@ func TestQueriesSameTransaction(t *testing.T) {
db, err := chai.Open(filepath.Join(dir, "pebble"))
require.NoError(t, err)
err = db.Update(func(tx *chai.Tx) error {
err = tx.Exec(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(true)
require.NoError(t, err)
defer tx.Rollback()
err = tx.Exec(`
CREATE TABLE test(a INT);
INSERT INTO test (a) VALUES (1), (2), (3), (4);
`)
require.NoError(t, err)
return nil
})
require.NoError(t, err)
err = tx.Commit()
require.NoError(t, err)
})
@@ -388,21 +417,28 @@ func TestQueriesSameTransaction(t *testing.T) {
db, err := chai.Open(filepath.Join(dir, "pebble"))
require.NoError(t, err)
err = db.Update(func(tx *chai.Tx) error {
st, err := tx.Query(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(true)
require.NoError(t, err)
defer tx.Rollback()
st, err := tx.Query(`
CREATE TABLE test(a INT);
INSERT INTO test (a) VALUES (1), (2), (3), (4);
UPDATE test SET a = 5;
SELECT * FROM test;
`)
require.NoError(t, err)
defer st.Close()
var buf bytes.Buffer
err = st.MarshalJSONTo(&buf)
require.NoError(t, err)
require.JSONEq(t, `[{"a": 5},{"a": 5},{"a": 5},{"a": 5}]`, buf.String())
return nil
})
require.NoError(t, err)
defer st.Close()
var buf bytes.Buffer
err = st.MarshalJSONTo(&buf)
require.NoError(t, err)
require.JSONEq(t, `[{"a": 5},{"a": 5},{"a": 5},{"a": 5}]`, buf.String())
err = tx.Commit()
require.NoError(t, err)
})
@@ -412,20 +448,27 @@ func TestQueriesSameTransaction(t *testing.T) {
db, err := chai.Open(filepath.Join(dir, "pebble"))
require.NoError(t, err)
err = db.Update(func(tx *chai.Tx) error {
r, err := tx.QueryRow(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(true)
require.NoError(t, err)
defer tx.Rollback()
r, err := tx.QueryRow(`
CREATE TABLE test(a INT);
INSERT INTO test (a) VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10);
DELETE FROM test WHERE a > 2;
SELECT COUNT(*) FROM test;
`)
require.NoError(t, err)
var count int
err = r.Scan(&count)
require.NoError(t, err)
require.Equal(t, 2, count)
return nil
})
require.NoError(t, err)
var count int
err = r.Scan(&count)
require.NoError(t, err)
require.Equal(t, 2, count)
err = tx.Commit()
require.NoError(t, err)
})
}

View File

@@ -25,16 +25,12 @@ func New(statements ...statement.Statement) Query {
type Context struct {
Ctx context.Context
DB *database.Database
Tx *database.Transaction
Conn *database.Connection
Params []environment.Param
}
func (c *Context) GetTx() *database.Transaction {
if c.Tx != nil {
return c.Tx
}
return c.DB.GetAttachedTx()
return c.Conn.GetTx()
}
// Prepare the statements by calling their Prepare methods.
@@ -62,7 +58,7 @@ func (q *Query) Prepare(context *Context) error {
if tx == nil {
tx = context.GetTx()
if tx == nil {
tx, err = context.DB.BeginTx(&database.TxOptions{
tx, err = context.Conn.BeginTx(&database.TxOptions{
ReadOnly: true,
})
if err != nil {
@@ -73,8 +69,9 @@ func (q *Query) Prepare(context *Context) error {
}
stmt, err := p.Prepare(&statement.Context{
DB: context.DB,
Tx: tx,
DB: context.DB,
Conn: context.Conn,
Tx: tx,
})
if err != nil {
return err
@@ -110,10 +107,10 @@ func (q Query) Run(context *Context) (*statement.Result, error) {
res = statement.Result{}
if qa, ok := stmt.(queryAlterer); ok {
err = qa.alterQuery(context.DB, &q)
err = qa.alterQuery(context.Conn, &q)
if err != nil {
if tx := context.GetTx(); tx != nil {
tx.Rollback()
_ = tx.Rollback()
}
return nil, err
}
@@ -122,7 +119,7 @@ func (q Query) Run(context *Context) (*statement.Result, error) {
}
if q.tx == nil {
q.tx, err = context.DB.BeginTx(&database.TxOptions{
q.tx, err = context.Conn.BeginTx(&database.TxOptions{
ReadOnly: stmt.IsReadOnly(),
})
if err != nil {
@@ -132,6 +129,7 @@ func (q Query) Run(context *Context) (*statement.Result, error) {
res, err = stmt.Run(&statement.Context{
DB: context.DB,
Conn: context.Conn,
Tx: q.tx,
Params: context.Params,
})
@@ -185,5 +183,5 @@ func (q Query) Run(context *Context) (*statement.Result, error) {
}
type queryAlterer interface {
alterQuery(db *database.Database, q *Query) error
alterQuery(conn *database.Connection, q *Query) error
}

View File

@@ -33,6 +33,10 @@ func TestDeleteStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = db.Exec("CREATE TABLE test(id INT PRIMARY KEY, a TEXT, b TEXT, c TEXT, d TEXT, e TEXT, n INT)")
require.NoError(t, err)
err = db.Exec("INSERT INTO test (id, a, b, c, n) VALUES (1, 'foo1', 'bar1', 'baz1', 3)")
@@ -42,14 +46,14 @@ func TestDeleteStmt(t *testing.T) {
err = db.Exec("INSERT INTO test (id, d, b, e, n) VALUES (3, 'foo3', 'bar2', 'bar3', 1)")
require.NoError(t, err)
err = db.Exec(test.query, test.params...)
err = conn.Exec(test.query, test.params...)
if test.fails {
require.Error(t, err)
return
}
require.NoError(t, err)
st, err := db.Query("SELECT id FROM test")
st, err := conn.Query("SELECT id FROM test")
require.NoError(t, err)
defer st.Close()

View File

@@ -15,22 +15,26 @@ func TestDropTable(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec("CREATE TABLE test1(a INT UNIQUE); CREATE TABLE test2(a INT); CREATE TABLE test3(a INT)")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec("CREATE TABLE test1(a INT UNIQUE); CREATE TABLE test2(a INT); CREATE TABLE test3(a INT)")
require.NoError(t, err)
err = db.Exec("DROP TABLE test1")
err = conn.Exec("DROP TABLE test1")
require.NoError(t, err)
err = db.Exec("DROP TABLE IF EXISTS test1")
err = conn.Exec("DROP TABLE IF EXISTS test1")
require.NoError(t, err)
// Dropping a table that doesn't exist without "IF EXISTS"
// should return an error.
err = db.Exec("DROP TABLE test1")
err = conn.Exec("DROP TABLE test1")
require.Error(t, err)
// Assert that no other table has been dropped.
res, err := db.Query("SELECT name FROM __chai_catalog WHERE type = 'table'")
res, err := conn.Query("SELECT name FROM __chai_catalog WHERE type = 'table'")
require.NoError(t, err)
var tables []string
err = res.Iterate(func(r *chai.Row) error {
@@ -49,18 +53,18 @@ func TestDropTable(t *testing.T) {
// Assert the unique index test1_a_idx, created upon the creation of the table,
// has been dropped as well.
_, err = db.QueryRow("SELECT 1 FROM __chai_catalog WHERE name = 'test1_a_idx'")
_, err = conn.QueryRow("SELECT 1 FROM __chai_catalog WHERE name = 'test1_a_idx'")
require.Error(t, err)
// Assert the rowid sequence test1_seq, created upon the creation of the table,
// has been dropped as well.
_, err = db.QueryRow("SELECT 1 FROM __chai_catalog WHERE name = 'test1_seq'")
_, err = conn.QueryRow("SELECT 1 FROM __chai_catalog WHERE name = 'test1_seq'")
require.Error(t, err)
_, err = db.QueryRow("SELECT 1 FROM __chai_sequence WHERE name = 'test1_seq'")
_, err = conn.QueryRow("SELECT 1 FROM __chai_sequence WHERE name = 'test1_seq'")
require.Error(t, err)
// Dropping a read-only table should fail.
err = db.Exec("DROP TABLE __chai_catalog")
err = conn.Exec("DROP TABLE __chai_catalog")
require.Error(t, err)
}

View File

@@ -31,10 +31,14 @@ func TestInsertStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec("CREATE TABLE test(a TEXT, b TEXT, c TEXT)")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec("CREATE TABLE test(a TEXT, b TEXT, c TEXT)")
require.NoError(t, err)
if withIndexes {
err = db.Exec(`
err = conn.Exec(`
CREATE INDEX idx_a ON test (a);
CREATE INDEX idx_b ON test (b);
CREATE INDEX idx_c ON test (c);
@@ -42,14 +46,14 @@ func TestInsertStmt(t *testing.T) {
require.NoError(t, err)
}
err = db.Exec(test.query, test.params...)
err = conn.Exec(test.query, test.params...)
if test.fails {
require.Error(t, err)
return
}
require.NoError(t, err)
st, err := db.Query("SELECT * FROM test")
st, err := conn.Query("SELECT * FROM test")
require.NoError(t, err)
defer st.Close()
@@ -82,13 +86,17 @@ func TestInsertStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec(`CREATE TABLE test(a int unique)`)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec(`CREATE TABLE test(a int unique)`)
require.NoError(t, err)
err = db.Exec(`insert into test (a) VALUES (1), (1)`)
err = conn.Exec(`insert into test (a) VALUES (1), (1)`)
require.Error(t, err)
res, err := db.Query("SELECT * FROM test")
res, err := conn.Query("SELECT * FROM test")
require.NoError(t, err)
defer res.Close()
@@ -100,13 +108,17 @@ func TestInsertStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec(`CREATE SEQUENCE seq; CREATE TABLE test(a int, b int default NEXT VALUE FOR seq)`)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec(`CREATE SEQUENCE seq; CREATE TABLE test(a int, b int default NEXT VALUE FOR seq)`)
require.NoError(t, err)
err = db.Exec(`insert into test (a) VALUES (1), (2), (3)`)
err = conn.Exec(`insert into test (a) VALUES (1), (2), (3)`)
require.NoError(t, err)
res, err := db.Query("SELECT * FROM test")
res, err := conn.Query("SELECT * FROM test")
require.NoError(t, err)
defer res.Close()
@@ -148,21 +160,25 @@ func TestInsertSelect(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec(`
CREATE TABLE foo(a INT, b INT, c INT, d INT, e INT);
CREATE TABLE bar(a INT, b INT, c INT, d INT, e INT);
INSERT INTO bar (a, b) VALUES (1, 10)
`)
require.NoError(t, err)
err = db.Exec(test.query, test.params...)
err = conn.Exec(test.query, test.params...)
if test.fails {
require.Error(t, err)
return
}
require.NoError(t, err)
st, err := db.Query("SELECT * FROM foo")
st, err := conn.Query("SELECT * FROM foo")
require.NoError(t, err)
defer st.Close()

View File

@@ -106,7 +106,11 @@ func TestSelectStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec(`--sql
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec(`--sql
CREATE TABLE test (
k INTEGER PRIMARY KEY,
color TEXT,
@@ -117,7 +121,7 @@ func TestSelectStmt(t *testing.T) {
)`)
require.NoError(t, err)
if withIndexes {
err = db.Exec(`
err = conn.Exec(`
CREATE INDEX idx_color ON test (color);
CREATE INDEX idx_size ON test (size);
CREATE INDEX idx_shape ON test (shape);
@@ -127,14 +131,14 @@ func TestSelectStmt(t *testing.T) {
require.NoError(t, err)
}
err = db.Exec("INSERT INTO test (k, color, size, shape) VALUES (1, 'red', 10, 'square')")
err = conn.Exec("INSERT INTO test (k, color, size, shape) VALUES (1, 'red', 10, 'square')")
require.NoError(t, err)
err = db.Exec("INSERT INTO test (k, color, size, weight) VALUES (2, 'blue', 10, 100)")
err = conn.Exec("INSERT INTO test (k, color, size, weight) VALUES (2, 'blue', 10, 100)")
require.NoError(t, err)
err = db.Exec("INSERT INTO test (k, height, weight) VALUES (3, 100, 200)")
err = conn.Exec("INSERT INTO test (k, height, weight) VALUES (3, 100, 200)")
require.NoError(t, err)
st, err := db.Query(test.query, test.params...)
st, err := conn.Query(test.query, test.params...)
defer st.Close()
if test.fails {
@@ -159,19 +163,23 @@ func TestSelectStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec("CREATE TABLE test (foo INTEGER PRIMARY KEY, bar TEXT)")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec("CREATE TABLE test (foo INTEGER PRIMARY KEY, bar TEXT)")
require.NoError(t, err)
err = db.Exec(`INSERT INTO test (foo, bar) VALUES (1, 'a')`)
err = conn.Exec(`INSERT INTO test (foo, bar) VALUES (1, 'a')`)
require.NoError(t, err)
err = db.Exec(`INSERT INTO test (foo, bar) VALUES (2, 'b')`)
err = conn.Exec(`INSERT INTO test (foo, bar) VALUES (2, 'b')`)
require.NoError(t, err)
err = db.Exec(`INSERT INTO test (foo, bar) VALUES (3, 'c')`)
err = conn.Exec(`INSERT INTO test (foo, bar) VALUES (3, 'c')`)
require.NoError(t, err)
err = db.Exec(`INSERT INTO test (foo, bar) VALUES (4, 'd')`)
err = conn.Exec(`INSERT INTO test (foo, bar) VALUES (4, 'd')`)
require.NoError(t, err)
st, err := db.Query("SELECT * FROM test WHERE foo < 400 AND foo >= 2")
st, err := conn.Query("SELECT * FROM test WHERE foo < 400 AND foo >= 2")
require.NoError(t, err)
defer st.Close()
@@ -195,13 +203,17 @@ func TestSelectStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec("CREATE TABLE test(foo INT); CREATE INDEX idx_foo ON test(foo);")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec("CREATE TABLE test(foo INT); CREATE INDEX idx_foo ON test(foo);")
require.NoError(t, err)
err = db.Exec(`INSERT INTO test (foo) VALUES (4), (2), (1), (3)`)
err = conn.Exec(`INSERT INTO test (foo) VALUES (4), (2), (1), (3)`)
require.NoError(t, err)
st, err := db.Query("SELECT * FROM test ORDER BY foo")
st, err := conn.Query("SELECT * FROM test ORDER BY foo")
require.NoError(t, err)
defer st.Close()
@@ -233,7 +245,11 @@ func TestSelectStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec(`
CREATE TABLE test(a INT);
INSERT INTO test (a) VALUES (1);
CREATE SEQUENCE seq;
@@ -241,7 +257,7 @@ func TestSelectStmt(t *testing.T) {
require.NoError(t, err)
// normal query
r, err := db.QueryRow("SELECT a, NEXT VALUE FOR seq FROM test")
r, err := conn.QueryRow("SELECT a, NEXT VALUE FOR seq FROM test")
require.NoError(t, err)
var a, seq int
err = r.Scan(&a, &seq)
@@ -250,7 +266,7 @@ func TestSelectStmt(t *testing.T) {
require.Equal(t, 1, seq)
// query with no table
r, err = db.QueryRow("SELECT NEXT VALUE FOR seq")
r, err = conn.QueryRow("SELECT NEXT VALUE FOR seq")
require.NoError(t, err)
err = r.Scan(&seq)
require.NoError(t, err)
@@ -262,13 +278,17 @@ func TestSelectStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec(`
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec(`
CREATE TABLE test(a INT);
INSERT INTO test (a) VALUES (1), (2), (3);
`)
require.NoError(t, err)
r, err := db.QueryRow("SELECT a FROM test LIMIT ? OFFSET ?", 1, 1)
r, err := conn.QueryRow("SELECT a FROM test LIMIT ? OFFSET ?", 1, 1)
require.NoError(t, err)
var a int
err = r.Scan(&a)
@@ -302,7 +322,11 @@ func TestDistinct(t *testing.T) {
require.NoError(t, err)
defer db.Close()
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.Begin(true)
require.NoError(t, err)
defer tx.Rollback()
@@ -334,7 +358,7 @@ func TestDistinct(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := db.Query(test.query)
q, err := conn.Query(test.query)
require.NoError(t, err)
defer q.Close()

View File

@@ -33,6 +33,7 @@ func (stmt *basePreparedStatement) Run(ctx *Context) (Result, error) {
type Context struct {
DB *database.Database
Conn *database.Connection
Tx *database.Transaction
Params []environment.Param
}

View File

@@ -40,29 +40,33 @@ func TestUpdateStmt(t *testing.T) {
require.NoError(t, err)
defer db.Close()
err = db.Exec("CREATE TABLE test (a text not null, b text, c text, d text, e text)")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec("CREATE TABLE test (a text not null, b text, c text, d text, e text)")
require.NoError(t, err)
if indexed {
err = db.Exec("CREATE INDEX idx_test_a ON test(a)")
err = conn.Exec("CREATE INDEX idx_test_a ON test(a)")
require.NoError(t, err)
}
err = db.Exec("INSERT INTO test (a, b, c) VALUES ('foo1', 'bar1', 'baz1')")
err = conn.Exec("INSERT INTO test (a, b, c) VALUES ('foo1', 'bar1', 'baz1')")
require.NoError(t, err)
err = db.Exec("INSERT INTO test (a, b) VALUES ('foo2', 'bar2')")
err = conn.Exec("INSERT INTO test (a, b) VALUES ('foo2', 'bar2')")
require.NoError(t, err)
err = db.Exec("INSERT INTO test (a, d, e) VALUES ('foo3', 'bar3', 'baz3')")
err = conn.Exec("INSERT INTO test (a, d, e) VALUES ('foo3', 'bar3', 'baz3')")
require.NoError(t, err)
err = db.Exec(test.query, test.params...)
err = conn.Exec(test.query, test.params...)
if test.fails {
require.Error(t, err)
return
}
require.NoError(t, err)
st, err := db.Query("SELECT * FROM test")
st, err := conn.Query("SELECT * FROM test")
require.NoError(t, err)
defer st.Close()

View File

@@ -6,6 +6,10 @@ import (
"github.com/cockroachdb/errors"
)
var _ queryAlterer = BeginStmt{}
var _ queryAlterer = RollbackStmt{}
var _ queryAlterer = CommitStmt{}
// BeginStmt is a statement that creates a new transaction.
type BeginStmt struct {
Writable bool
@@ -16,15 +20,14 @@ func (stmt BeginStmt) Prepare(*statement.Context) (statement.Statement, error) {
return stmt, nil
}
func (stmt BeginStmt) alterQuery(db *database.Database, q *Query) error {
func (stmt BeginStmt) alterQuery(conn *database.Connection, q *Query) error {
if q.tx != nil {
return errors.New("cannot begin a transaction within a transaction")
}
var err error
q.tx, err = db.BeginTx(&database.TxOptions{
q.tx, err = conn.BeginTx(&database.TxOptions{
ReadOnly: !stmt.Writable,
Attached: true,
})
q.autoCommit = false
return err
@@ -46,7 +49,7 @@ func (stmt RollbackStmt) Prepare(*statement.Context) (statement.Statement, error
return stmt, nil
}
func (stmt RollbackStmt) alterQuery(db *database.Database, q *Query) error {
func (stmt RollbackStmt) alterQuery(conn *database.Connection, q *Query) error {
if q.tx == nil || q.autoCommit {
return errors.New("cannot rollback with no active transaction")
}
@@ -76,7 +79,7 @@ func (stmt CommitStmt) Prepare(*statement.Context) (statement.Statement, error)
return stmt, nil
}
func (stmt CommitStmt) alterQuery(db *database.Database, q *Query) error {
func (stmt CommitStmt) alterQuery(conn *database.Connection, q *Query) error {
if q.tx == nil || q.autoCommit {
return errors.New("cannot commit with no active transaction")
}

View File

@@ -30,10 +30,14 @@ func TestTransactionRun(t *testing.T) {
db, err := chai.Open(":memory:")
require.NoError(t, err)
defer db.Close()
defer db.Exec("ROLLBACK")
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
defer conn.Exec("ROLLBACK")
for _, q := range test.queries {
err = db.Exec(q)
err = conn.Exec(q)
if err != nil {
break
}

View File

@@ -62,16 +62,19 @@ func TestParserDelete(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db := testutil.NewTestDB(t)
testutil.MustExec(t, db, nil, "CREATE TABLE test(age int)")
db, tx, cleanup := testutil.NewTestTx(t)
defer cleanup()
testutil.MustExec(t, db, tx, "CREATE TABLE test(age int)")
q, err := parser.ParseQuery(test.s)
require.NoError(t, err)
err = q.Prepare(&query.Context{
Ctx: context.Background(),
DB: db,
Ctx: context.Background(),
DB: db,
Conn: tx.Connection(),
})
require.NoError(t, err)

View File

@@ -200,9 +200,10 @@ func TestParserInsert(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db := testutil.NewTestDB(t)
db, tx, cleanup := testutil.NewTestTx(t)
defer cleanup()
testutil.MustExec(t, db, nil, "CREATE TABLE test(a TEXT, b TEXT); CREATE TABLE foo(c TEXT, d TEXT);")
testutil.MustExec(t, db, tx, "CREATE TABLE test(a TEXT, b TEXT); CREATE TABLE foo(c TEXT, d TEXT);")
q, err := parser.ParseQuery(test.s)
if test.fails {
@@ -212,8 +213,9 @@ func TestParserInsert(t *testing.T) {
require.NoError(t, err)
err = q.Prepare(&query.Context{
Ctx: context.Background(),
DB: db,
Ctx: context.Background(),
DB: db,
Conn: tx.Connection(),
})
require.NoError(t, err)

View File

@@ -353,9 +353,10 @@ func TestParserSelect(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
if !test.mustFail {
db := testutil.NewTestDB(t)
db, tx, cleanup := testutil.NewTestTx(t)
defer cleanup()
testutil.MustExec(t, db, nil, `
testutil.MustExec(t, db, tx, `
CREATE TABLE test(a TEXT, b TEXT, age int);
CREATE TABLE test1(age INT, a INT);
CREATE TABLE test2(age INT, a INT);
@@ -367,8 +368,9 @@ func TestParserSelect(t *testing.T) {
)
err = q.Prepare(&query.Context{
Ctx: context.Background(),
DB: db,
Ctx: context.Background(),
DB: db,
Conn: tx.Connection(),
})
require.NoError(t, err)

View File

@@ -49,9 +49,10 @@ func TestParserUpdate(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
db := testutil.NewTestDB(t)
db, tx, cleanup := testutil.NewTestTx(t)
defer cleanup()
testutil.MustExec(t, db, nil, "CREATE TABLE test(a INT, b TEXT)")
testutil.MustExec(t, db, tx, "CREATE TABLE test(a INT, b TEXT)")
q, err := parser.ParseQuery(test.s)
if test.errored {
@@ -61,8 +62,9 @@ func TestParserUpdate(t *testing.T) {
require.NoError(t, err)
err = q.Prepare(&query.Context{
Ctx: context.Background(),
DB: db,
Ctx: context.Background(),
DB: db,
Conn: tx.Connection(),
})
require.NoError(t, err)

View File

@@ -142,6 +142,7 @@ type DiscardOperator struct {
func Discard() *DiscardOperator {
return &DiscardOperator{}
}
func (it *DiscardOperator) Clone() Operator {
return &DiscardOperator{
BaseOperator: it.BaseOperator.Clone(),

View File

@@ -60,12 +60,28 @@ func NewTestDB(t testing.TB) *database.Database {
return db
}
func NewTestConn(t testing.TB, db *database.Database) *database.Connection {
t.Helper()
conn, err := db.Connect()
require.NoError(t, err)
t.Cleanup(func() {
conn.Close()
})
return conn
}
func NewTestTx(t testing.TB) (*database.Database, *database.Transaction, func()) {
t.Helper()
db := NewTestDB(t)
conn := NewTestConn(t, db)
tx, err := db.Begin(true)
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
return db, tx, func() {
@@ -91,7 +107,7 @@ func Query(db *database.Database, tx *database.Transaction, q string, params ...
return nil, err
}
ctx := &query.Context{Ctx: context.Background(), DB: db, Tx: tx, Params: params}
ctx := &query.Context{Ctx: context.Background(), DB: db, Conn: tx.Connection(), Params: params}
err = pq.Prepare(ctx)
if err != nil {
return nil, err

View File

@@ -111,7 +111,13 @@ func TestSQL(t *testing.T) {
if test.Fails {
exec := func() error {
res, err := db.Query(test.Expr)
conn, err := db.Connect()
if err != nil {
return err
}
defer conn.Close()
res, err := conn.Query(test.Expr)
if err != nil {
return err
}
@@ -131,7 +137,11 @@ func TestSQL(t *testing.T) {
require.Errorf(t, err, "\nSource:%s:%d expected\n%s\nto raise an error but got none", absPath, test.Line, test.Expr)
}
} else {
res, err := db.Query(test.Expr)
conn, err := db.Connect()
require.NoError(t, err, "Source: %s:%d", absPath, test.Line)
defer conn.Close()
res, err := conn.Query(test.Expr)
require.NoError(t, err, "Source: %s:%d", absPath, test.Line)
defer res.Close()