mirror of
https://github.com/chaisql/chai.git
synced 2025-09-26 19:51:21 +08:00
db: add Connection
This commit is contained in:
@@ -31,12 +31,17 @@ type execer func(q string, args ...interface{}) error
|
||||
// If tables is provided, only selected tables will be outputted.
|
||||
func Bench(db *chai.DB, query string, opt BenchOptions) error {
|
||||
var tx *chai.Tx
|
||||
var p preparer = db
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var p preparer = conn
|
||||
var e execer = db.Exec
|
||||
var err error
|
||||
|
||||
if opt.SameTx {
|
||||
tx, err = db.Begin(true)
|
||||
tx, err = conn.Begin(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -12,7 +12,13 @@ import (
|
||||
// Dump takes a database and dumps its content as SQL queries in the given writer.
|
||||
// If tables is provided, only selected tables will be outputted.
|
||||
func Dump(db *chai.DB, w io.Writer, tables ...string) error {
|
||||
tx, err := db.Begin(false)
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -97,7 +103,13 @@ func dumpTable(tx *chai.Tx, w io.Writer, query, tableName string) error {
|
||||
// DumpSchema takes a database and dumps its schema as SQL queries in the given writer.
|
||||
// If tables are provided, only selected tables will be outputted.
|
||||
func DumpSchema(db *chai.DB, w io.Writer, tables ...string) error {
|
||||
tx, err := db.Begin(false)
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -19,11 +19,18 @@ func ExecSQL(ctx context.Context, db *chai.DB, r io.Reader, w io.Writer) error {
|
||||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return parser.NewParser(r).Parse(func(s statement.Statement) error {
|
||||
qq := query.New(s)
|
||||
qctx := query.Context{
|
||||
Ctx: ctx,
|
||||
DB: db.DB,
|
||||
Ctx: ctx,
|
||||
DB: db.DB,
|
||||
Conn: conn.Conn,
|
||||
}
|
||||
err := qq.Prepare(&qctx)
|
||||
if err != nil {
|
||||
|
@@ -50,7 +50,13 @@ func ListIndexes(db *chai.DB, tableName string) ([]string, error) {
|
||||
if tableName != "" {
|
||||
q += " AND owner_table_name = ?"
|
||||
}
|
||||
res, err := db.Query(q, tableName)
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
res, err := conn.Query(q, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -115,7 +115,13 @@ func runHelpCmd(out io.Writer) error {
|
||||
|
||||
// runTablesCmd displays all tables.
|
||||
func runTablesCmd(db *chai.DB, w io.Writer) error {
|
||||
res, err := db.Query("SELECT name FROM __chai_catalog WHERE type = 'table' AND name NOT LIKE '__chai_%'")
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
res, err := conn.Query("SELECT name FROM __chai_catalog WHERE type = 'table' AND name NOT LIKE '__chai_%'")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -195,7 +201,13 @@ func runImportCmd(db *chai.DB, fileType, path, table string) error {
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
152
db.go
152
db.go
@@ -47,21 +47,64 @@ func Open(path string) (*DB, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *DB) Connect() (*Connection, error) {
|
||||
conn, err := db.DB.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Connection{
|
||||
db: db,
|
||||
Conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WithContext creates a new database handle using the given context for every operation.
|
||||
func (db DB) WithContext(ctx context.Context) *DB {
|
||||
db.ctx = ctx
|
||||
return &db
|
||||
}
|
||||
|
||||
func (db *DB) withConn(fn func(*Connection) error) error {
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return fn(conn)
|
||||
}
|
||||
|
||||
// QueryRow runs the query and returns the first row.
|
||||
func (db *DB) QueryRow(q string, args ...any) (r *Row, err error) {
|
||||
err = db.withConn(func(c *Connection) error {
|
||||
r, err = c.QueryRow(q, args...)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Exec a query against the database without returning the result.
|
||||
func (db *DB) Exec(q string, args ...any) error {
|
||||
return db.withConn(func(c *Connection) error {
|
||||
return c.Exec(q, args...)
|
||||
})
|
||||
}
|
||||
|
||||
// Close the database.
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
db *DB
|
||||
Conn *database.Connection
|
||||
}
|
||||
|
||||
// Begin starts a new transaction.
|
||||
// The returned transaction must be closed either by calling Rollback or Commit.
|
||||
func (db *DB) Begin(writable bool) (*Tx, error) {
|
||||
tx, err := db.DB.BeginTx(&database.TxOptions{
|
||||
func (c *Connection) Begin(writable bool) (*Tx, error) {
|
||||
_, err := c.Conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: !writable,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -69,14 +112,13 @@ func (db *DB) Begin(writable bool) (*Tx, error) {
|
||||
}
|
||||
|
||||
return &Tx{
|
||||
db: db,
|
||||
tx: tx,
|
||||
conn: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// View starts a read only transaction, runs fn and automatically rolls it back.
|
||||
func (db *DB) View(fn func(tx *Tx) error) error {
|
||||
tx, err := db.Begin(false)
|
||||
func (c *Connection) View(fn func(tx *Tx) error) error {
|
||||
tx, err := c.Begin(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -86,8 +128,8 @@ func (db *DB) View(fn func(tx *Tx) error) error {
|
||||
}
|
||||
|
||||
// Update starts a read-write transaction, runs fn and automatically commits it.
|
||||
func (db *DB) Update(fn func(tx *Tx) error) error {
|
||||
tx, err := db.Begin(true)
|
||||
func (c *Connection) Update(fn func(tx *Tx) error) error {
|
||||
tx, err := c.Begin(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -103,18 +145,25 @@ func (db *DB) Update(fn func(tx *Tx) error) error {
|
||||
|
||||
// Query the database and return the result.
|
||||
// The returned result must always be closed after usage.
|
||||
func (db *DB) Query(q string, args ...any) (*Result, error) {
|
||||
stmt, err := db.Prepare(q)
|
||||
func (c *Connection) Query(q string, args ...any) (*Result, error) {
|
||||
stmt, err := c.Prepare(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stmt.Query(args...)
|
||||
res, err := stmt.Query(args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res.conn = c
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// QueryRow runs the query and returns the first row.
|
||||
func (db *DB) QueryRow(q string, args ...any) (*Row, error) {
|
||||
stmt, err := db.Prepare(q)
|
||||
func (c *Connection) QueryRow(q string, args ...any) (*Row, error) {
|
||||
stmt, err := c.Prepare(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -123,8 +172,8 @@ func (db *DB) QueryRow(q string, args ...any) (*Row, error) {
|
||||
}
|
||||
|
||||
// Exec a query against the database without returning the result.
|
||||
func (db *DB) Exec(q string, args ...any) error {
|
||||
stmt, err := db.Prepare(q)
|
||||
func (c *Connection) Exec(q string, args ...any) error {
|
||||
stmt, err := c.Prepare(q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -133,41 +182,54 @@ func (db *DB) Exec(q string, args ...any) error {
|
||||
}
|
||||
|
||||
// Prepare parses the query and returns a prepared statement.
|
||||
func (db *DB) Prepare(q string) (*Statement, error) {
|
||||
func (c *Connection) Prepare(q string) (*Statement, error) {
|
||||
pq, err := parser.ParseQuery(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = pq.Prepare(newQueryContext(db, nil, nil))
|
||||
err = pq.Prepare(newQueryContext(c, nil))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Statement{
|
||||
pq: pq,
|
||||
db: db,
|
||||
pq: pq,
|
||||
conn: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Close() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// Tx represents a database transaction. It provides methods for managing the
|
||||
// collection of tables and the transaction itself.
|
||||
// Tx is either read-only or read/write. Read-only can be used to read tables
|
||||
// and read/write can be used to read, create, delete and modify tables.
|
||||
type Tx struct {
|
||||
db *DB
|
||||
tx *database.Transaction
|
||||
conn *Connection
|
||||
}
|
||||
|
||||
// Rollback the transaction. Can be used safely after commit.
|
||||
func (tx *Tx) Rollback() error {
|
||||
return tx.tx.Rollback()
|
||||
t := tx.conn.Conn.GetTx()
|
||||
if t == nil {
|
||||
return errors.New("transaction has already been committed or rolled back")
|
||||
}
|
||||
|
||||
return t.Rollback()
|
||||
}
|
||||
|
||||
// Commit the transaction. Calling this method on read-only transactions
|
||||
// will return an error.
|
||||
func (tx *Tx) Commit() error {
|
||||
return tx.tx.Commit()
|
||||
t := tx.conn.Conn.GetTx()
|
||||
if t == nil {
|
||||
return errors.New("transaction has already been committed or rolled back")
|
||||
}
|
||||
|
||||
return t.Commit()
|
||||
}
|
||||
|
||||
// Query the database withing the transaction and returns the result.
|
||||
@@ -208,15 +270,15 @@ func (tx *Tx) Prepare(q string) (*Statement, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = pq.Prepare(newQueryContext(tx.db, tx, nil))
|
||||
err = pq.Prepare(newQueryContext(tx.conn, nil))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Statement{
|
||||
pq: pq,
|
||||
db: tx.db,
|
||||
tx: tx,
|
||||
pq: pq,
|
||||
conn: tx.conn,
|
||||
tx: tx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -225,9 +287,9 @@ func (tx *Tx) Prepare(q string) (*Statement, error) {
|
||||
// is valid until the DB closes.
|
||||
// It's safe for concurrent use by multiple goroutines.
|
||||
type Statement struct {
|
||||
pq query.Query
|
||||
db *DB
|
||||
tx *Tx
|
||||
pq query.Query
|
||||
conn *Connection
|
||||
tx *Tx
|
||||
}
|
||||
|
||||
// Query the database and return the result.
|
||||
@@ -236,12 +298,12 @@ func (s *Statement) Query(args ...any) (*Result, error) {
|
||||
var r *statement.Result
|
||||
var err error
|
||||
|
||||
r, err = s.pq.Run(newQueryContext(s.db, s.tx, argsToParams(args)))
|
||||
r, err = s.pq.Run(newQueryContext(s.conn, argsToParams(args)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Result{result: r, ctx: s.db.ctx}, nil
|
||||
return &Result{result: r, ctx: s.conn.db.ctx}, nil
|
||||
}
|
||||
|
||||
func argsToParams(args []interface{}) []environment.Param {
|
||||
@@ -310,6 +372,7 @@ func (s *Statement) Exec(args ...any) (err error) {
|
||||
type Result struct {
|
||||
result *statement.Result
|
||||
ctx context.Context
|
||||
conn *Connection
|
||||
}
|
||||
|
||||
func (r *Result) Iterate(fn func(r *Row) error) error {
|
||||
@@ -366,12 +429,12 @@ func (r *Result) Columns() []string {
|
||||
break
|
||||
}
|
||||
|
||||
fields := make([]string, len(po.Exprs))
|
||||
columns := make([]string, len(po.Exprs))
|
||||
for i := range po.Exprs {
|
||||
fields[i] = po.Exprs[i].String()
|
||||
columns[i] = po.Exprs[i].String()
|
||||
}
|
||||
|
||||
return fields
|
||||
return columns
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,7 +448,9 @@ func (r *Result) Close() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.result.Close()
|
||||
err = r.result.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Result) MarshalJSON() ([]byte, error) {
|
||||
@@ -427,18 +492,13 @@ func (r *Result) MarshalJSONTo(w io.Writer) error {
|
||||
return buf.Flush()
|
||||
}
|
||||
|
||||
func newQueryContext(db *DB, tx *Tx, params []environment.Param) *query.Context {
|
||||
ctx := query.Context{
|
||||
Ctx: db.ctx,
|
||||
DB: db.DB,
|
||||
func newQueryContext(conn *Connection, params []environment.Param) *query.Context {
|
||||
return &query.Context{
|
||||
Ctx: conn.db.ctx,
|
||||
DB: conn.db.DB,
|
||||
Conn: conn.Conn,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
ctx.Tx = tx.tx
|
||||
}
|
||||
|
||||
return &ctx
|
||||
}
|
||||
|
||||
type Row struct {
|
||||
|
161
db_test.go
161
db_test.go
@@ -3,7 +3,6 @@ package chai_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -17,11 +16,17 @@ import (
|
||||
func ExampleTx() {
|
||||
db, err := chai.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
panic(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -98,7 +103,11 @@ func TestOpen(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
res1, err := db.Query("SELECT * FROM __chai_catalog")
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
res1, err := conn.Query("SELECT * FROM __chai_catalog")
|
||||
require.NoError(t, err)
|
||||
defer res1.Close()
|
||||
|
||||
@@ -144,7 +153,11 @@ func TestQueryRow(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Exec(`
|
||||
@@ -165,7 +178,11 @@ func TestQueryRow(t *testing.T) {
|
||||
require.Equal(t, 1, a)
|
||||
require.Equal(t, "foo", b)
|
||||
|
||||
tx, err := db.Begin(false)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(false)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -182,7 +199,11 @@ func TestQueryRow(t *testing.T) {
|
||||
require.True(t, chai.IsNotFoundError(err))
|
||||
require.Nil(t, r)
|
||||
|
||||
tx, err := db.Begin(false)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(false)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
r, err = tx.QueryRow("SELECT * FROM test WHERE a > 100")
|
||||
@@ -196,10 +217,14 @@ func TestPrepareThreadSafe(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec("CREATE TABLE test(a int unique, b text); INSERT INTO test(a, b) VALUES (1, 'a'), (2, 'a')")
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec("CREATE TABLE test(a int unique, b text); INSERT INTO test(a, b) VALUES (1, 'a'), (2, 'a')")
|
||||
require.NoError(t, err)
|
||||
|
||||
stmt, err := db.Prepare("SELECT COUNT(a) FROM test WHERE a < ? GROUP BY b ORDER BY a DESC LIMIT 5")
|
||||
stmt, err := conn.Prepare("SELECT COUNT(a) FROM test WHERE a < ? GROUP BY b ORDER BY a DESC LIMIT 5")
|
||||
require.NoError(t, err)
|
||||
|
||||
g, _ := errgroup.WithContext(context.Background())
|
||||
@@ -240,7 +265,11 @@ func TestIterateDeepCopy(t *testing.T) {
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := db.Query(`SELECT * FROM foo ORDER BY a DESC`)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
res, err := conn.Query(`SELECT * FROM foo ORDER BY a DESC`)
|
||||
require.NoError(t, err)
|
||||
|
||||
type item struct {
|
||||
@@ -263,115 +292,3 @@ func TestIterateDeepCopy(t *testing.T) {
|
||||
require.Equal(t, &item{A: 2, B: "sample text 2"}, items[0])
|
||||
require.Equal(t, &item{A: 1, B: "sample text 1"}, items[1])
|
||||
}
|
||||
|
||||
func BenchmarkSelect(b *testing.B) {
|
||||
for size := 1; size <= 10000; size *= 10 {
|
||||
b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) {
|
||||
db, err := chai.Open(":memory:")
|
||||
require.NoError(b, err)
|
||||
|
||||
err = db.Exec("CREATE TABLE foo")
|
||||
require.NoError(b, err)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
err = db.Exec("INSERT INTO foo(a, b) VALUES (1, 2);")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, _ := db.Query("SELECT * FROM foo")
|
||||
res.Iterate(func(d *chai.Row) error { return nil })
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectWhere(b *testing.B) {
|
||||
for size := 1; size <= 10000; size *= 10 {
|
||||
b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) {
|
||||
db, err := chai.Open(":memory:")
|
||||
require.NoError(b, err)
|
||||
|
||||
err = db.Exec("CREATE TABLE foo")
|
||||
require.NoError(b, err)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
err = db.Exec("INSERT INTO foo(a, b) VALUES (1, 2);")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, _ := db.Query("SELECT b FROM foo WHERE a > 0")
|
||||
res.Iterate(func(d *chai.Row) error { return nil })
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPreparedSelectWhere(b *testing.B) {
|
||||
for size := 1; size <= 10000; size *= 10 {
|
||||
b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) {
|
||||
db, err := chai.Open(":memory:")
|
||||
require.NoError(b, err)
|
||||
|
||||
err = db.Exec("CREATE TABLE foo")
|
||||
require.NoError(b, err)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
err = db.Exec("INSERT INTO foo(a, b) VALUES (1, 2);")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
p, _ := db.Prepare("SELECT b FROM foo WHERE a > 0")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, _ := p.Query()
|
||||
res.Iterate(func(d *chai.Row) error { return nil })
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectPk(b *testing.B) {
|
||||
for size := 1; size <= 10000; size *= 10 {
|
||||
b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) {
|
||||
db, err := chai.Open(":memory:")
|
||||
require.NoError(b, err)
|
||||
|
||||
err = db.Exec("CREATE TABLE foo(a INT PRIMARY KEY)")
|
||||
require.NoError(b, err)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
err = db.Exec("INSERT INTO foo(a) VALUES (?)", i)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, _ := db.Query("SELECT * FROM foo WHERE a = ?", size-1)
|
||||
res.Iterate(func(d *chai.Row) error { return nil })
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkInsert(b *testing.B) {
|
||||
db, err := chai.Open(b.TempDir())
|
||||
require.NoError(b, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec("CREATE TABLE foo(a INT)")
|
||||
require.NoError(b, err)
|
||||
|
||||
stmt, err := db.Prepare("INSERT INTO foo(a) VALUES (?)")
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for j := 0; j < 100; j++ {
|
||||
stmt.Exec(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
117
driver/driver.go
117
driver/driver.go
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -39,13 +38,10 @@ func (d sqlDriver) OpenConnector(name string) (driver.Connector, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &connector{
|
||||
return &connector{
|
||||
db: db,
|
||||
driver: d,
|
||||
}
|
||||
runtime.SetFinalizer(c, (*connector).Close)
|
||||
|
||||
return c, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -54,15 +50,21 @@ var (
|
||||
)
|
||||
|
||||
type connector struct {
|
||||
driver driver.Driver
|
||||
|
||||
db *chai.DB
|
||||
|
||||
driver driver.Driver
|
||||
db *chai.DB
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
return &conn{db: c.db}, nil
|
||||
cc, err := c.db.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conn{
|
||||
db: c.db,
|
||||
conn: cc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *connector) Driver() driver.Driver {
|
||||
@@ -80,8 +82,8 @@ func (c *connector) Close() error {
|
||||
// conn represents a connection to the Chai database.
|
||||
// It implements the database/sql/driver.Conn interface.
|
||||
type conn struct {
|
||||
db *chai.DB
|
||||
tx *chai.Tx
|
||||
db *chai.DB
|
||||
conn *chai.Connection
|
||||
}
|
||||
|
||||
// Prepare returns a prepared statement, bound to this connection.
|
||||
@@ -91,14 +93,7 @@ func (c *conn) Prepare(q string) (driver.Stmt, error) {
|
||||
|
||||
// PrepareContext returns a prepared statement, bound to this connection.
|
||||
func (c *conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
|
||||
var s *chai.Statement
|
||||
var err error
|
||||
|
||||
if c.tx != nil {
|
||||
s, err = c.tx.Prepare(q)
|
||||
} else {
|
||||
s, err = c.db.Prepare(q)
|
||||
}
|
||||
s, err := c.conn.Prepare(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -110,11 +105,7 @@ func (c *conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error
|
||||
|
||||
// Close closes any ongoing transaction.
|
||||
func (c *conn) Close() error {
|
||||
if c.tx != nil {
|
||||
return c.tx.Rollback()
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Begin starts and returns a new transaction.
|
||||
@@ -122,6 +113,15 @@ func (c *conn) Begin() (driver.Tx, error) {
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
func (c *conn) ResetSession(ctx context.Context) error {
|
||||
err := c.conn.Conn.Reset()
|
||||
if err != nil {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginTx starts and returns a new transaction.
|
||||
// It uses the ReadOnly option to determine whether to start a read-only or read/write transaction.
|
||||
// If the Isolation option is non zero, an error is returned.
|
||||
@@ -130,26 +130,9 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
return nil, errors.New("isolation levels are not supported")
|
||||
}
|
||||
|
||||
db := c.db.WithContext(ctx)
|
||||
|
||||
// if the ReadOnly flag is explicitly specified, create a read-only transaction,
|
||||
// otherwise create a read/write transaction.
|
||||
var err error
|
||||
c.tx, err = db.Begin(!opts.ReadOnly)
|
||||
|
||||
return c, err
|
||||
}
|
||||
|
||||
func (c *conn) Commit() error {
|
||||
err := c.tx.Commit()
|
||||
c.tx = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Rollback() error {
|
||||
err := c.tx.Rollback()
|
||||
c.tx = nil
|
||||
return err
|
||||
return c.conn.Begin(!opts.ReadOnly)
|
||||
}
|
||||
|
||||
// Stmt is a prepared statement. It is bound to a Conn and not
|
||||
@@ -176,18 +159,18 @@ func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver
|
||||
default:
|
||||
}
|
||||
|
||||
return result{}, s.stmt.Exec(driverNamedValueToParams(args)...)
|
||||
return execResult{}, s.stmt.Exec(namedValueToParams(args)...)
|
||||
}
|
||||
|
||||
type result struct{}
|
||||
type execResult struct{}
|
||||
|
||||
// LastInsertId is not supported and returns an error.
|
||||
func (r result) LastInsertId() (int64, error) {
|
||||
func (r execResult) LastInsertId() (int64, error) {
|
||||
return 0, errors.New("not supported")
|
||||
}
|
||||
|
||||
// RowsAffected is not supported and returns an error.
|
||||
func (r result) RowsAffected() (int64, error) {
|
||||
func (r execResult) RowsAffected() (int64, error) {
|
||||
return 0, errors.New("not supported")
|
||||
}
|
||||
|
||||
@@ -204,17 +187,17 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
default:
|
||||
}
|
||||
|
||||
res, err := s.stmt.Query(driverNamedValueToParams(args)...)
|
||||
res, err := s.stmt.Query(namedValueToParams(args)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rs := newRecordStream(res)
|
||||
rs := newRows(res)
|
||||
rs.columns = res.Columns()
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
func driverNamedValueToParams(args []driver.NamedValue) []any {
|
||||
func namedValueToParams(args []driver.NamedValue) []any {
|
||||
params := make([]any, len(args))
|
||||
for i, arg := range args {
|
||||
var p environment.Param
|
||||
@@ -233,35 +216,35 @@ func (s stmt) Close() error {
|
||||
|
||||
var errStop = errors.New("stop")
|
||||
|
||||
type recordStream struct {
|
||||
type Rows struct {
|
||||
res *chai.Result
|
||||
cancelFn func()
|
||||
c chan recordRow
|
||||
c chan Row
|
||||
wg sync.WaitGroup
|
||||
columns []string
|
||||
}
|
||||
|
||||
type recordRow struct {
|
||||
type Row struct {
|
||||
r *chai.Row
|
||||
err error
|
||||
}
|
||||
|
||||
func newRecordStream(res *chai.Result) *recordStream {
|
||||
func newRows(res *chai.Result) *Rows {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
ds := recordStream{
|
||||
rs := Rows{
|
||||
res: res,
|
||||
cancelFn: cancel,
|
||||
c: make(chan recordRow),
|
||||
c: make(chan Row),
|
||||
}
|
||||
ds.wg.Add(1)
|
||||
rs.wg.Add(1)
|
||||
|
||||
go ds.iterate(ctx)
|
||||
go rs.iterate(ctx)
|
||||
|
||||
return &ds
|
||||
return &rs
|
||||
}
|
||||
|
||||
func (rs *recordStream) iterate(ctx context.Context) {
|
||||
func (rs *Rows) iterate(ctx context.Context) {
|
||||
defer rs.wg.Done()
|
||||
defer close(rs.c)
|
||||
|
||||
@@ -275,7 +258,7 @@ func (rs *recordStream) iterate(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errStop
|
||||
case rs.c <- recordRow{
|
||||
case rs.c <- Row{
|
||||
r: r,
|
||||
}:
|
||||
|
||||
@@ -292,7 +275,7 @@ func (rs *recordStream) iterate(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
rs.c <- recordRow{
|
||||
rs.c <- Row{
|
||||
err: err,
|
||||
}
|
||||
return
|
||||
@@ -300,19 +283,19 @@ func (rs *recordStream) iterate(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Columns returns the fields selected by the SELECT statement.
|
||||
func (rs *recordStream) Columns() []string {
|
||||
func (rs *Rows) Columns() []string {
|
||||
return rs.res.Columns()
|
||||
}
|
||||
|
||||
// Close closes the rows iterator.
|
||||
func (rs *recordStream) Close() error {
|
||||
func (rs *Rows) Close() error {
|
||||
rs.cancelFn()
|
||||
rs.wg.Wait()
|
||||
return rs.res.Close()
|
||||
}
|
||||
|
||||
func (rs *recordStream) Next(dest []driver.Value) error {
|
||||
rs.c <- recordRow{}
|
||||
func (rs *Rows) Next(dest []driver.Value) error {
|
||||
rs.c <- Row{}
|
||||
|
||||
row, ok := <-rs.c
|
||||
if !ok {
|
||||
|
@@ -38,8 +38,14 @@ func Example() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Query some rows
|
||||
stream, err := db.Query("SELECT * FROM user WHERE id > ?", 1)
|
||||
stream, err := conn.Query("SELECT * FROM user WHERE id > ?", 1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@@ -404,6 +404,7 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
|
||||
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
@@ -489,6 +490,7 @@ golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su
|
||||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/oauth2 v0.0.0-20170207211851-4464e7848382/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
@@ -500,6 +502,7 @@ golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4Iltr
|
||||
golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4=
|
||||
golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o=
|
||||
golang.org/x/perf v0.0.0-20230113213139-801c7ef9e5c5 h1:ObuXPmIgI4ZMyQLIz48cJYgSyWdjUXc2SZAdyJMwEAU=
|
||||
golang.org/x/perf v0.0.0-20230113213139-801c7ef9e5c5/go.mod h1:UBKtEnL8aqnd+0JHqZ+2qoMDwtuy6cYhhKNoHLBiTQc=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
|
@@ -441,7 +441,11 @@ func TestReadOnlyTables(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
res, err := db.Query(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
res, err := conn.Query(`
|
||||
CREATE TABLE foo (a int, b double unique, c text);
|
||||
CREATE INDEX idx_foo_a ON foo(a, c);
|
||||
SELECT * FROM __chai_catalog
|
||||
@@ -565,20 +569,28 @@ func TestCatalogConcurrency(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
conn1, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
|
||||
// create a table
|
||||
err = db.Exec(`
|
||||
err = conn1.Exec(`
|
||||
CREATE TABLE test (a int);
|
||||
CREATE INDEX idx_test_a ON test(a);
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// start a transaction rt1
|
||||
rt1, err := db.Begin(false)
|
||||
rt1, err := conn1.Begin(false)
|
||||
require.NoError(t, err)
|
||||
defer rt1.Rollback()
|
||||
|
||||
conn2, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
|
||||
// start a transaction wt2
|
||||
wt1, err := db.Begin(true)
|
||||
wt1, err := conn2.Begin(true)
|
||||
require.NoError(t, err)
|
||||
defer wt1.Rollback()
|
||||
|
||||
|
68
internal/database/connection.go
Normal file
68
internal/database/connection.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
db *Database
|
||||
ctx context.Context
|
||||
tx *Transaction
|
||||
}
|
||||
|
||||
// BeginTx starts a new transaction with the given options.
|
||||
// If opts is empty, it will use the default options.
|
||||
// The returned transaction must be closed either by calling Rollback or Commit.
|
||||
func (c *Connection) BeginTx(opts *TxOptions) (*Transaction, error) {
|
||||
if c.ctx.Err() != nil {
|
||||
return nil, errors.New("connection is closed")
|
||||
}
|
||||
|
||||
if c.tx != nil {
|
||||
return nil, errors.New("cannot open a transaction within a transaction")
|
||||
}
|
||||
|
||||
tx, err := c.db.beginTx(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.tx = tx
|
||||
tx.conn = c
|
||||
tx.OnRollbackHooks = append(tx.OnRollbackHooks, c.releaseAttachedTx)
|
||||
tx.OnCommitHooks = append(tx.OnCommitHooks, c.releaseAttachedTx)
|
||||
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Reset() error {
|
||||
if c.tx != nil {
|
||||
return errors.New("cannot reset a connection with an attached transaction")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) releaseAttachedTx() {
|
||||
if c.tx != nil {
|
||||
c.tx = nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetAttachedTx returns the transaction attached to the connection, if any.
|
||||
// The returned transaction is not thread safe.
|
||||
func (c *Connection) GetTx() *Transaction {
|
||||
return c.tx
|
||||
}
|
||||
|
||||
func (c *Connection) Close() error {
|
||||
defer c.db.connectionWg.Done()
|
||||
|
||||
if c.tx != nil {
|
||||
return c.tx.Rollback()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@@ -2,6 +2,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -19,12 +20,12 @@ type Database struct {
|
||||
catalogMu sync.RWMutex
|
||||
catalog *Catalog
|
||||
|
||||
// If this is non-nil, the user is running an explicit transaction
|
||||
// using the BEGIN statement.
|
||||
// Only one attached transaction can be run at a time and any calls to DB.Begin()
|
||||
// will cause an error until that transaction is rolled back or commited.
|
||||
attachedTransaction *Transaction
|
||||
attachedTxMu sync.Mutex
|
||||
// context used to notify all connections that the database is closing.
|
||||
closeContext context.Context
|
||||
closeCancel context.CancelFunc
|
||||
|
||||
// waitgroup to wait for all connections to be closed.
|
||||
connectionWg sync.WaitGroup
|
||||
|
||||
// This is used to prevent creating a new transaction
|
||||
// during certain operations (commit, close, etc.)
|
||||
@@ -33,11 +34,11 @@ type Database struct {
|
||||
// This limits the number of write transactions to 1.
|
||||
writetxmu sync.Mutex
|
||||
|
||||
// TransactionIDs is used to assign transaction an ID at runtime.
|
||||
// transactionIDs is used to assign transaction an ID at runtime.
|
||||
// Since transaction IDs are not persisted and not used for concurrent
|
||||
// access, we can use 8 bytes ids that will be reset every time
|
||||
// the database restarts.
|
||||
TransactionIDs uint64
|
||||
transactionIDs atomic.Uint64
|
||||
|
||||
closeOnce sync.Once
|
||||
|
||||
@@ -62,10 +63,6 @@ type CatalogLoader interface {
|
||||
type TxOptions struct {
|
||||
// Open a read-only transaction.
|
||||
ReadOnly bool
|
||||
// Set the transaction as global at the database level.
|
||||
// Any queries run by the database will use that transaction until it is
|
||||
// rolled back or commited.
|
||||
Attached bool
|
||||
}
|
||||
|
||||
func Open(path string, opts *Options) (*Database, error) {
|
||||
@@ -82,6 +79,9 @@ func Open(path string, opts *Options) (*Database, error) {
|
||||
Engine: store,
|
||||
}
|
||||
|
||||
// create a context that will be cancelled when the database is closed.
|
||||
db.closeContext, db.closeCancel = context.WithCancel(context.Background())
|
||||
|
||||
// ensure the rollback segment doesn't contain any data that needs to be rolled back
|
||||
// due to a previous crash.
|
||||
err = db.Engine.Recover()
|
||||
@@ -129,6 +129,9 @@ func (db *Database) Close() error {
|
||||
var err error
|
||||
|
||||
db.closeOnce.Do(func() {
|
||||
db.closeCancel()
|
||||
|
||||
db.connectionWg.Wait()
|
||||
err = db.closeDatabase()
|
||||
})
|
||||
|
||||
@@ -136,16 +139,8 @@ func (db *Database) Close() error {
|
||||
}
|
||||
|
||||
func (db *Database) closeDatabase() error {
|
||||
// If there is an attached transaction
|
||||
// it must be rolled back before closing the engine.
|
||||
if tx := db.GetAttachedTx(); tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
db.writetxmu.Lock()
|
||||
defer db.writetxmu.Unlock()
|
||||
|
||||
// release all sequences
|
||||
tx, err := db.beginTx(nil)
|
||||
tx, err := db.beginTxUnlocked(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -171,20 +166,29 @@ func (db *Database) closeDatabase() error {
|
||||
return db.Engine.Close()
|
||||
}
|
||||
|
||||
// GetAttachedTx returns the transaction attached to the database. It returns nil if there is no
|
||||
// such transaction.
|
||||
// The returned transaction is not thread safe.
|
||||
func (db *Database) GetAttachedTx() *Transaction {
|
||||
db.attachedTxMu.Lock()
|
||||
defer db.attachedTxMu.Unlock()
|
||||
// Connect returns a new connection to the database.
|
||||
// The returned connection is not thread safe.
|
||||
// It is the caller's responsibility to close the connection.
|
||||
func (db *Database) Connect() (*Connection, error) {
|
||||
if db.closeContext.Err() != nil {
|
||||
return nil, errors.New("database is closed")
|
||||
}
|
||||
|
||||
return db.attachedTransaction
|
||||
db.connectionWg.Add(1)
|
||||
return &Connection{
|
||||
db: db,
|
||||
ctx: db.closeContext,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Begin starts a new transaction with default options.
|
||||
// The returned transaction must be closed either by calling Rollback or Commit.
|
||||
func (db *Database) Begin(writable bool) (*Transaction, error) {
|
||||
return db.BeginTx(&TxOptions{
|
||||
if db.closeContext.Err() != nil {
|
||||
return nil, errors.New("database is closed")
|
||||
}
|
||||
|
||||
return db.beginTx(&TxOptions{
|
||||
ReadOnly: !writable,
|
||||
})
|
||||
}
|
||||
@@ -192,10 +196,11 @@ func (db *Database) Begin(writable bool) (*Transaction, error) {
|
||||
// BeginTx starts a new transaction with the given options.
|
||||
// If opts is empty, it will use the default options.
|
||||
// The returned transaction must be closed either by calling Rollback or Commit.
|
||||
// If the Attached option is passed, it opens a database level transaction, which gets
|
||||
// attached to the database and prevents any other transaction to be opened afterwards
|
||||
// until it gets rolled back or commited.
|
||||
func (db *Database) BeginTx(opts *TxOptions) (*Transaction, error) {
|
||||
func (db *Database) beginTx(opts *TxOptions) (*Transaction, error) {
|
||||
if db.closeContext.Err() != nil {
|
||||
return nil, errors.New("database is closed")
|
||||
}
|
||||
|
||||
db.txmu.RLock()
|
||||
defer db.txmu.RUnlock()
|
||||
|
||||
@@ -207,18 +212,11 @@ func (db *Database) BeginTx(opts *TxOptions) (*Transaction, error) {
|
||||
db.writetxmu.Lock()
|
||||
}
|
||||
|
||||
db.attachedTxMu.Lock()
|
||||
defer db.attachedTxMu.Unlock()
|
||||
|
||||
if db.attachedTransaction != nil {
|
||||
return nil, errors.New("cannot open a transaction within a transaction")
|
||||
}
|
||||
|
||||
return db.beginTx(opts)
|
||||
return db.beginTxUnlocked(opts)
|
||||
}
|
||||
|
||||
// beginTx creates a transaction without locks.
|
||||
func (db *Database) beginTx(opts *TxOptions) (*Transaction, error) {
|
||||
// beginTxUnlocked creates a transaction without locks.
|
||||
func (db *Database) beginTxUnlocked(opts *TxOptions) (*Transaction, error) {
|
||||
if opts == nil {
|
||||
opts = &TxOptions{}
|
||||
}
|
||||
@@ -235,7 +233,7 @@ func (db *Database) beginTx(opts *TxOptions) (*Transaction, error) {
|
||||
Engine: db.Engine,
|
||||
Session: sess,
|
||||
Writable: !opts.ReadOnly,
|
||||
ID: atomic.AddUint64(&db.TransactionIDs, 1),
|
||||
ID: db.transactionIDs.Add(1),
|
||||
Catalog: db.Catalog(),
|
||||
TxStart: time.Now(),
|
||||
}
|
||||
@@ -244,12 +242,6 @@ func (db *Database) beginTx(opts *TxOptions) (*Transaction, error) {
|
||||
tx.WriteTxMu = &db.writetxmu
|
||||
}
|
||||
|
||||
if opts.Attached {
|
||||
db.attachedTransaction = &tx
|
||||
tx.OnRollbackHooks = append(tx.OnRollbackHooks, db.releaseAttachedTx)
|
||||
tx.OnCommitHooks = append(tx.OnCommitHooks, db.releaseAttachedTx)
|
||||
}
|
||||
|
||||
return &tx, nil
|
||||
}
|
||||
|
||||
@@ -265,12 +257,3 @@ func (db *Database) SetCatalog(c *Catalog) {
|
||||
db.catalog = c
|
||||
db.catalogMu.Unlock()
|
||||
}
|
||||
|
||||
func (db *Database) releaseAttachedTx() {
|
||||
db.attachedTxMu.Lock()
|
||||
defer db.attachedTxMu.Unlock()
|
||||
|
||||
if db.attachedTransaction != nil {
|
||||
db.attachedTransaction = nil
|
||||
}
|
||||
}
|
||||
|
@@ -21,7 +21,11 @@ func TestConcurrentTransactionManagement(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
// 1. Start transaction T1.
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start transaction T2.
|
||||
@@ -42,7 +46,11 @@ func TestConcurrentTransactionManagement(t *testing.T) {
|
||||
|
||||
// 2. Attempt to start transaction T2.
|
||||
// Waits for T1 to finish.
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tx.Rollback())
|
||||
|
||||
|
@@ -13,7 +13,8 @@ import (
|
||||
// Transaction is either read-only or read/write. Read-only can be used to read tables
|
||||
// and read/write can be used to read, create, delete and modify tables.
|
||||
type Transaction struct {
|
||||
db *Database
|
||||
db *Database
|
||||
conn *Connection
|
||||
|
||||
// Timestamp at which the transaction was created.
|
||||
// The timestamp must use the local timezone.
|
||||
@@ -33,6 +34,10 @@ type Transaction struct {
|
||||
catalogWriter *CatalogWriter
|
||||
}
|
||||
|
||||
func (tx *Transaction) Connection() *Connection {
|
||||
return tx.conn
|
||||
}
|
||||
|
||||
// Rollback the transaction. Can be used safely after commit.
|
||||
func (tx *Transaction) Rollback() error {
|
||||
err := tx.Session.Close()
|
||||
|
@@ -249,6 +249,10 @@ func TestQueries(t *testing.T) {
|
||||
db, err := chai.Open(filepath.Join(dir, "pebble"))
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
r, err := db.QueryRow(`
|
||||
CREATE TABLE test(a INT);
|
||||
INSERT INTO test (a) VALUES (1), (2), (3), (4);
|
||||
@@ -261,7 +265,7 @@ func TestQueries(t *testing.T) {
|
||||
require.Equal(t, 4, count)
|
||||
|
||||
t.Run("ORDER BY", func(t *testing.T) {
|
||||
st, err := db.Query("SELECT * FROM test ORDER BY a DESC")
|
||||
st, err := conn.Query("SELECT * FROM test ORDER BY a DESC")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
@@ -297,7 +301,11 @@ func TestQueries(t *testing.T) {
|
||||
db, err := chai.Open(filepath.Join(dir, "pebble"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st, err := db.Query(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
st, err := conn.Query(`
|
||||
CREATE TABLE test(a INT);
|
||||
INSERT INTO test (a) VALUES (1), (2), (3), (4);
|
||||
UPDATE test SET a = 5;
|
||||
@@ -317,16 +325,23 @@ func TestQueries(t *testing.T) {
|
||||
db, err := chai.Open(filepath.Join(dir, "pebble"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec("CREATE TABLE test(a INT)")
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec("CREATE TABLE test(a INT)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *chai.Tx) error {
|
||||
for i := 1; i < 200; i++ {
|
||||
err = tx.Exec("INSERT INTO test (a) VALUES (?)", i)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
for i := 1; i < 200; i++ {
|
||||
err = tx.Exec("INSERT INTO test (a) VALUES (?)", i)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := db.QueryRow(`
|
||||
@@ -349,19 +364,26 @@ func TestQueriesSameTransaction(t *testing.T) {
|
||||
db, err := chai.Open(filepath.Join(dir, "pebble"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *chai.Tx) error {
|
||||
r, err := tx.QueryRow(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
r, err := tx.QueryRow(`
|
||||
CREATE TABLE test(a INT);
|
||||
INSERT INTO test (a) VALUES (1), (2), (3), (4);
|
||||
SELECT COUNT(*) FROM test;
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
var count int
|
||||
err = r.Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 4, count)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
var count int
|
||||
err = r.Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 4, count)
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
@@ -371,14 +393,21 @@ func TestQueriesSameTransaction(t *testing.T) {
|
||||
db, err := chai.Open(filepath.Join(dir, "pebble"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *chai.Tx) error {
|
||||
err = tx.Exec(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
err = tx.Exec(`
|
||||
CREATE TABLE test(a INT);
|
||||
INSERT INTO test (a) VALUES (1), (2), (3), (4);
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
@@ -388,21 +417,28 @@ func TestQueriesSameTransaction(t *testing.T) {
|
||||
db, err := chai.Open(filepath.Join(dir, "pebble"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *chai.Tx) error {
|
||||
st, err := tx.Query(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
st, err := tx.Query(`
|
||||
CREATE TABLE test(a INT);
|
||||
INSERT INTO test (a) VALUES (1), (2), (3), (4);
|
||||
UPDATE test SET a = 5;
|
||||
SELECT * FROM test;
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
var buf bytes.Buffer
|
||||
err = st.MarshalJSONTo(&buf)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `[{"a": 5},{"a": 5},{"a": 5},{"a": 5}]`, buf.String())
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
var buf bytes.Buffer
|
||||
err = st.MarshalJSONTo(&buf)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `[{"a": 5},{"a": 5},{"a": 5},{"a": 5}]`, buf.String())
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
@@ -412,20 +448,27 @@ func TestQueriesSameTransaction(t *testing.T) {
|
||||
db, err := chai.Open(filepath.Join(dir, "pebble"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *chai.Tx) error {
|
||||
r, err := tx.QueryRow(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
r, err := tx.QueryRow(`
|
||||
CREATE TABLE test(a INT);
|
||||
INSERT INTO test (a) VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10);
|
||||
DELETE FROM test WHERE a > 2;
|
||||
SELECT COUNT(*) FROM test;
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
var count int
|
||||
err = r.Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, count)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
var count int
|
||||
err = r.Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, count)
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
@@ -25,16 +25,12 @@ func New(statements ...statement.Statement) Query {
|
||||
type Context struct {
|
||||
Ctx context.Context
|
||||
DB *database.Database
|
||||
Tx *database.Transaction
|
||||
Conn *database.Connection
|
||||
Params []environment.Param
|
||||
}
|
||||
|
||||
func (c *Context) GetTx() *database.Transaction {
|
||||
if c.Tx != nil {
|
||||
return c.Tx
|
||||
}
|
||||
|
||||
return c.DB.GetAttachedTx()
|
||||
return c.Conn.GetTx()
|
||||
}
|
||||
|
||||
// Prepare the statements by calling their Prepare methods.
|
||||
@@ -62,7 +58,7 @@ func (q *Query) Prepare(context *Context) error {
|
||||
if tx == nil {
|
||||
tx = context.GetTx()
|
||||
if tx == nil {
|
||||
tx, err = context.DB.BeginTx(&database.TxOptions{
|
||||
tx, err = context.Conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -73,8 +69,9 @@ func (q *Query) Prepare(context *Context) error {
|
||||
}
|
||||
|
||||
stmt, err := p.Prepare(&statement.Context{
|
||||
DB: context.DB,
|
||||
Tx: tx,
|
||||
DB: context.DB,
|
||||
Conn: context.Conn,
|
||||
Tx: tx,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -110,10 +107,10 @@ func (q Query) Run(context *Context) (*statement.Result, error) {
|
||||
res = statement.Result{}
|
||||
|
||||
if qa, ok := stmt.(queryAlterer); ok {
|
||||
err = qa.alterQuery(context.DB, &q)
|
||||
err = qa.alterQuery(context.Conn, &q)
|
||||
if err != nil {
|
||||
if tx := context.GetTx(); tx != nil {
|
||||
tx.Rollback()
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -122,7 +119,7 @@ func (q Query) Run(context *Context) (*statement.Result, error) {
|
||||
}
|
||||
|
||||
if q.tx == nil {
|
||||
q.tx, err = context.DB.BeginTx(&database.TxOptions{
|
||||
q.tx, err = context.Conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: stmt.IsReadOnly(),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -132,6 +129,7 @@ func (q Query) Run(context *Context) (*statement.Result, error) {
|
||||
|
||||
res, err = stmt.Run(&statement.Context{
|
||||
DB: context.DB,
|
||||
Conn: context.Conn,
|
||||
Tx: q.tx,
|
||||
Params: context.Params,
|
||||
})
|
||||
@@ -185,5 +183,5 @@ func (q Query) Run(context *Context) (*statement.Result, error) {
|
||||
}
|
||||
|
||||
type queryAlterer interface {
|
||||
alterQuery(db *database.Database, q *Query) error
|
||||
alterQuery(conn *database.Connection, q *Query) error
|
||||
}
|
||||
|
@@ -33,6 +33,10 @@ func TestDeleteStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = db.Exec("CREATE TABLE test(id INT PRIMARY KEY, a TEXT, b TEXT, c TEXT, d TEXT, e TEXT, n INT)")
|
||||
require.NoError(t, err)
|
||||
err = db.Exec("INSERT INTO test (id, a, b, c, n) VALUES (1, 'foo1', 'bar1', 'baz1', 3)")
|
||||
@@ -42,14 +46,14 @@ func TestDeleteStmt(t *testing.T) {
|
||||
err = db.Exec("INSERT INTO test (id, d, b, e, n) VALUES (3, 'foo3', 'bar2', 'bar3', 1)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec(test.query, test.params...)
|
||||
err = conn.Exec(test.query, test.params...)
|
||||
if test.fails {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
st, err := db.Query("SELECT id FROM test")
|
||||
st, err := conn.Query("SELECT id FROM test")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
|
@@ -15,22 +15,26 @@ func TestDropTable(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec("CREATE TABLE test1(a INT UNIQUE); CREATE TABLE test2(a INT); CREATE TABLE test3(a INT)")
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec("CREATE TABLE test1(a INT UNIQUE); CREATE TABLE test2(a INT); CREATE TABLE test3(a INT)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec("DROP TABLE test1")
|
||||
err = conn.Exec("DROP TABLE test1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec("DROP TABLE IF EXISTS test1")
|
||||
err = conn.Exec("DROP TABLE IF EXISTS test1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Dropping a table that doesn't exist without "IF EXISTS"
|
||||
// should return an error.
|
||||
err = db.Exec("DROP TABLE test1")
|
||||
err = conn.Exec("DROP TABLE test1")
|
||||
require.Error(t, err)
|
||||
|
||||
// Assert that no other table has been dropped.
|
||||
res, err := db.Query("SELECT name FROM __chai_catalog WHERE type = 'table'")
|
||||
res, err := conn.Query("SELECT name FROM __chai_catalog WHERE type = 'table'")
|
||||
require.NoError(t, err)
|
||||
var tables []string
|
||||
err = res.Iterate(func(r *chai.Row) error {
|
||||
@@ -49,18 +53,18 @@ func TestDropTable(t *testing.T) {
|
||||
|
||||
// Assert the unique index test1_a_idx, created upon the creation of the table,
|
||||
// has been dropped as well.
|
||||
_, err = db.QueryRow("SELECT 1 FROM __chai_catalog WHERE name = 'test1_a_idx'")
|
||||
_, err = conn.QueryRow("SELECT 1 FROM __chai_catalog WHERE name = 'test1_a_idx'")
|
||||
require.Error(t, err)
|
||||
|
||||
// Assert the rowid sequence test1_seq, created upon the creation of the table,
|
||||
// has been dropped as well.
|
||||
_, err = db.QueryRow("SELECT 1 FROM __chai_catalog WHERE name = 'test1_seq'")
|
||||
_, err = conn.QueryRow("SELECT 1 FROM __chai_catalog WHERE name = 'test1_seq'")
|
||||
require.Error(t, err)
|
||||
_, err = db.QueryRow("SELECT 1 FROM __chai_sequence WHERE name = 'test1_seq'")
|
||||
_, err = conn.QueryRow("SELECT 1 FROM __chai_sequence WHERE name = 'test1_seq'")
|
||||
require.Error(t, err)
|
||||
|
||||
// Dropping a read-only table should fail.
|
||||
err = db.Exec("DROP TABLE __chai_catalog")
|
||||
err = conn.Exec("DROP TABLE __chai_catalog")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
|
@@ -31,10 +31,14 @@ func TestInsertStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec("CREATE TABLE test(a TEXT, b TEXT, c TEXT)")
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec("CREATE TABLE test(a TEXT, b TEXT, c TEXT)")
|
||||
require.NoError(t, err)
|
||||
if withIndexes {
|
||||
err = db.Exec(`
|
||||
err = conn.Exec(`
|
||||
CREATE INDEX idx_a ON test (a);
|
||||
CREATE INDEX idx_b ON test (b);
|
||||
CREATE INDEX idx_c ON test (c);
|
||||
@@ -42,14 +46,14 @@ func TestInsertStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err = db.Exec(test.query, test.params...)
|
||||
err = conn.Exec(test.query, test.params...)
|
||||
if test.fails {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
st, err := db.Query("SELECT * FROM test")
|
||||
st, err := conn.Query("SELECT * FROM test")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
@@ -82,13 +86,17 @@ func TestInsertStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE test(a int unique)`)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec(`CREATE TABLE test(a int unique)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec(`insert into test (a) VALUES (1), (1)`)
|
||||
err = conn.Exec(`insert into test (a) VALUES (1), (1)`)
|
||||
require.Error(t, err)
|
||||
|
||||
res, err := db.Query("SELECT * FROM test")
|
||||
res, err := conn.Query("SELECT * FROM test")
|
||||
require.NoError(t, err)
|
||||
defer res.Close()
|
||||
|
||||
@@ -100,13 +108,17 @@ func TestInsertStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE SEQUENCE seq; CREATE TABLE test(a int, b int default NEXT VALUE FOR seq)`)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec(`CREATE SEQUENCE seq; CREATE TABLE test(a int, b int default NEXT VALUE FOR seq)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec(`insert into test (a) VALUES (1), (2), (3)`)
|
||||
err = conn.Exec(`insert into test (a) VALUES (1), (2), (3)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := db.Query("SELECT * FROM test")
|
||||
res, err := conn.Query("SELECT * FROM test")
|
||||
require.NoError(t, err)
|
||||
defer res.Close()
|
||||
|
||||
@@ -148,21 +160,25 @@ func TestInsertSelect(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec(`
|
||||
CREATE TABLE foo(a INT, b INT, c INT, d INT, e INT);
|
||||
CREATE TABLE bar(a INT, b INT, c INT, d INT, e INT);
|
||||
INSERT INTO bar (a, b) VALUES (1, 10)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec(test.query, test.params...)
|
||||
err = conn.Exec(test.query, test.params...)
|
||||
if test.fails {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
st, err := db.Query("SELECT * FROM foo")
|
||||
st, err := conn.Query("SELECT * FROM foo")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
|
@@ -106,7 +106,11 @@ func TestSelectStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`--sql
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec(`--sql
|
||||
CREATE TABLE test (
|
||||
k INTEGER PRIMARY KEY,
|
||||
color TEXT,
|
||||
@@ -117,7 +121,7 @@ func TestSelectStmt(t *testing.T) {
|
||||
)`)
|
||||
require.NoError(t, err)
|
||||
if withIndexes {
|
||||
err = db.Exec(`
|
||||
err = conn.Exec(`
|
||||
CREATE INDEX idx_color ON test (color);
|
||||
CREATE INDEX idx_size ON test (size);
|
||||
CREATE INDEX idx_shape ON test (shape);
|
||||
@@ -127,14 +131,14 @@ func TestSelectStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err = db.Exec("INSERT INTO test (k, color, size, shape) VALUES (1, 'red', 10, 'square')")
|
||||
err = conn.Exec("INSERT INTO test (k, color, size, shape) VALUES (1, 'red', 10, 'square')")
|
||||
require.NoError(t, err)
|
||||
err = db.Exec("INSERT INTO test (k, color, size, weight) VALUES (2, 'blue', 10, 100)")
|
||||
err = conn.Exec("INSERT INTO test (k, color, size, weight) VALUES (2, 'blue', 10, 100)")
|
||||
require.NoError(t, err)
|
||||
err = db.Exec("INSERT INTO test (k, height, weight) VALUES (3, 100, 200)")
|
||||
err = conn.Exec("INSERT INTO test (k, height, weight) VALUES (3, 100, 200)")
|
||||
require.NoError(t, err)
|
||||
|
||||
st, err := db.Query(test.query, test.params...)
|
||||
st, err := conn.Query(test.query, test.params...)
|
||||
defer st.Close()
|
||||
|
||||
if test.fails {
|
||||
@@ -159,19 +163,23 @@ func TestSelectStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec("CREATE TABLE test (foo INTEGER PRIMARY KEY, bar TEXT)")
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec("CREATE TABLE test (foo INTEGER PRIMARY KEY, bar TEXT)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec(`INSERT INTO test (foo, bar) VALUES (1, 'a')`)
|
||||
err = conn.Exec(`INSERT INTO test (foo, bar) VALUES (1, 'a')`)
|
||||
require.NoError(t, err)
|
||||
err = db.Exec(`INSERT INTO test (foo, bar) VALUES (2, 'b')`)
|
||||
err = conn.Exec(`INSERT INTO test (foo, bar) VALUES (2, 'b')`)
|
||||
require.NoError(t, err)
|
||||
err = db.Exec(`INSERT INTO test (foo, bar) VALUES (3, 'c')`)
|
||||
err = conn.Exec(`INSERT INTO test (foo, bar) VALUES (3, 'c')`)
|
||||
require.NoError(t, err)
|
||||
err = db.Exec(`INSERT INTO test (foo, bar) VALUES (4, 'd')`)
|
||||
err = conn.Exec(`INSERT INTO test (foo, bar) VALUES (4, 'd')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
st, err := db.Query("SELECT * FROM test WHERE foo < 400 AND foo >= 2")
|
||||
st, err := conn.Query("SELECT * FROM test WHERE foo < 400 AND foo >= 2")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
@@ -195,13 +203,17 @@ func TestSelectStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec("CREATE TABLE test(foo INT); CREATE INDEX idx_foo ON test(foo);")
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec("CREATE TABLE test(foo INT); CREATE INDEX idx_foo ON test(foo);")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec(`INSERT INTO test (foo) VALUES (4), (2), (1), (3)`)
|
||||
err = conn.Exec(`INSERT INTO test (foo) VALUES (4), (2), (1), (3)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
st, err := db.Query("SELECT * FROM test ORDER BY foo")
|
||||
st, err := conn.Query("SELECT * FROM test ORDER BY foo")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
@@ -233,7 +245,11 @@ func TestSelectStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec(`
|
||||
CREATE TABLE test(a INT);
|
||||
INSERT INTO test (a) VALUES (1);
|
||||
CREATE SEQUENCE seq;
|
||||
@@ -241,7 +257,7 @@ func TestSelectStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// normal query
|
||||
r, err := db.QueryRow("SELECT a, NEXT VALUE FOR seq FROM test")
|
||||
r, err := conn.QueryRow("SELECT a, NEXT VALUE FOR seq FROM test")
|
||||
require.NoError(t, err)
|
||||
var a, seq int
|
||||
err = r.Scan(&a, &seq)
|
||||
@@ -250,7 +266,7 @@ func TestSelectStmt(t *testing.T) {
|
||||
require.Equal(t, 1, seq)
|
||||
|
||||
// query with no table
|
||||
r, err = db.QueryRow("SELECT NEXT VALUE FOR seq")
|
||||
r, err = conn.QueryRow("SELECT NEXT VALUE FOR seq")
|
||||
require.NoError(t, err)
|
||||
err = r.Scan(&seq)
|
||||
require.NoError(t, err)
|
||||
@@ -262,13 +278,17 @@ func TestSelectStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec(`
|
||||
CREATE TABLE test(a INT);
|
||||
INSERT INTO test (a) VALUES (1), (2), (3);
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := db.QueryRow("SELECT a FROM test LIMIT ? OFFSET ?", 1, 1)
|
||||
r, err := conn.QueryRow("SELECT a FROM test LIMIT ? OFFSET ?", 1, 1)
|
||||
require.NoError(t, err)
|
||||
var a int
|
||||
err = r.Scan(&a)
|
||||
@@ -302,7 +322,11 @@ func TestDistinct(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin(true)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -334,7 +358,7 @@ func TestDistinct(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := db.Query(test.query)
|
||||
q, err := conn.Query(test.query)
|
||||
require.NoError(t, err)
|
||||
defer q.Close()
|
||||
|
||||
|
@@ -33,6 +33,7 @@ func (stmt *basePreparedStatement) Run(ctx *Context) (Result, error) {
|
||||
|
||||
type Context struct {
|
||||
DB *database.Database
|
||||
Conn *database.Connection
|
||||
Tx *database.Transaction
|
||||
Params []environment.Param
|
||||
}
|
||||
|
@@ -40,29 +40,33 @@ func TestUpdateStmt(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec("CREATE TABLE test (a text not null, b text, c text, d text, e text)")
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec("CREATE TABLE test (a text not null, b text, c text, d text, e text)")
|
||||
require.NoError(t, err)
|
||||
|
||||
if indexed {
|
||||
err = db.Exec("CREATE INDEX idx_test_a ON test(a)")
|
||||
err = conn.Exec("CREATE INDEX idx_test_a ON test(a)")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err = db.Exec("INSERT INTO test (a, b, c) VALUES ('foo1', 'bar1', 'baz1')")
|
||||
err = conn.Exec("INSERT INTO test (a, b, c) VALUES ('foo1', 'bar1', 'baz1')")
|
||||
require.NoError(t, err)
|
||||
err = db.Exec("INSERT INTO test (a, b) VALUES ('foo2', 'bar2')")
|
||||
err = conn.Exec("INSERT INTO test (a, b) VALUES ('foo2', 'bar2')")
|
||||
require.NoError(t, err)
|
||||
err = db.Exec("INSERT INTO test (a, d, e) VALUES ('foo3', 'bar3', 'baz3')")
|
||||
err = conn.Exec("INSERT INTO test (a, d, e) VALUES ('foo3', 'bar3', 'baz3')")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Exec(test.query, test.params...)
|
||||
err = conn.Exec(test.query, test.params...)
|
||||
if test.fails {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
st, err := db.Query("SELECT * FROM test")
|
||||
st, err := conn.Query("SELECT * FROM test")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
|
@@ -6,6 +6,10 @@ import (
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
var _ queryAlterer = BeginStmt{}
|
||||
var _ queryAlterer = RollbackStmt{}
|
||||
var _ queryAlterer = CommitStmt{}
|
||||
|
||||
// BeginStmt is a statement that creates a new transaction.
|
||||
type BeginStmt struct {
|
||||
Writable bool
|
||||
@@ -16,15 +20,14 @@ func (stmt BeginStmt) Prepare(*statement.Context) (statement.Statement, error) {
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt BeginStmt) alterQuery(db *database.Database, q *Query) error {
|
||||
func (stmt BeginStmt) alterQuery(conn *database.Connection, q *Query) error {
|
||||
if q.tx != nil {
|
||||
return errors.New("cannot begin a transaction within a transaction")
|
||||
}
|
||||
|
||||
var err error
|
||||
q.tx, err = db.BeginTx(&database.TxOptions{
|
||||
q.tx, err = conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: !stmt.Writable,
|
||||
Attached: true,
|
||||
})
|
||||
q.autoCommit = false
|
||||
return err
|
||||
@@ -46,7 +49,7 @@ func (stmt RollbackStmt) Prepare(*statement.Context) (statement.Statement, error
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt RollbackStmt) alterQuery(db *database.Database, q *Query) error {
|
||||
func (stmt RollbackStmt) alterQuery(conn *database.Connection, q *Query) error {
|
||||
if q.tx == nil || q.autoCommit {
|
||||
return errors.New("cannot rollback with no active transaction")
|
||||
}
|
||||
@@ -76,7 +79,7 @@ func (stmt CommitStmt) Prepare(*statement.Context) (statement.Statement, error)
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt CommitStmt) alterQuery(db *database.Database, q *Query) error {
|
||||
func (stmt CommitStmt) alterQuery(conn *database.Connection, q *Query) error {
|
||||
if q.tx == nil || q.autoCommit {
|
||||
return errors.New("cannot commit with no active transaction")
|
||||
}
|
||||
|
@@ -30,10 +30,14 @@ func TestTransactionRun(t *testing.T) {
|
||||
db, err := chai.Open(":memory:")
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer db.Exec("ROLLBACK")
|
||||
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
defer conn.Exec("ROLLBACK")
|
||||
|
||||
for _, q := range test.queries {
|
||||
err = db.Exec(q)
|
||||
err = conn.Exec(q)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
@@ -62,16 +62,19 @@ func TestParserDelete(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
|
||||
testutil.MustExec(t, db, nil, "CREATE TABLE test(age int)")
|
||||
db, tx, cleanup := testutil.NewTestTx(t)
|
||||
defer cleanup()
|
||||
|
||||
testutil.MustExec(t, db, tx, "CREATE TABLE test(age int)")
|
||||
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = q.Prepare(&query.Context{
|
||||
Ctx: context.Background(),
|
||||
DB: db,
|
||||
Ctx: context.Background(),
|
||||
DB: db,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@@ -200,9 +200,10 @@ func TestParserInsert(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
db, tx, cleanup := testutil.NewTestTx(t)
|
||||
defer cleanup()
|
||||
|
||||
testutil.MustExec(t, db, nil, "CREATE TABLE test(a TEXT, b TEXT); CREATE TABLE foo(c TEXT, d TEXT);")
|
||||
testutil.MustExec(t, db, tx, "CREATE TABLE test(a TEXT, b TEXT); CREATE TABLE foo(c TEXT, d TEXT);")
|
||||
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
if test.fails {
|
||||
@@ -212,8 +213,9 @@ func TestParserInsert(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = q.Prepare(&query.Context{
|
||||
Ctx: context.Background(),
|
||||
DB: db,
|
||||
Ctx: context.Background(),
|
||||
DB: db,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@@ -353,9 +353,10 @@ func TestParserSelect(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
if !test.mustFail {
|
||||
db := testutil.NewTestDB(t)
|
||||
db, tx, cleanup := testutil.NewTestTx(t)
|
||||
defer cleanup()
|
||||
|
||||
testutil.MustExec(t, db, nil, `
|
||||
testutil.MustExec(t, db, tx, `
|
||||
CREATE TABLE test(a TEXT, b TEXT, age int);
|
||||
CREATE TABLE test1(age INT, a INT);
|
||||
CREATE TABLE test2(age INT, a INT);
|
||||
@@ -367,8 +368,9 @@ func TestParserSelect(t *testing.T) {
|
||||
)
|
||||
|
||||
err = q.Prepare(&query.Context{
|
||||
Ctx: context.Background(),
|
||||
DB: db,
|
||||
Ctx: context.Background(),
|
||||
DB: db,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@@ -49,9 +49,10 @@ func TestParserUpdate(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
db, tx, cleanup := testutil.NewTestTx(t)
|
||||
defer cleanup()
|
||||
|
||||
testutil.MustExec(t, db, nil, "CREATE TABLE test(a INT, b TEXT)")
|
||||
testutil.MustExec(t, db, tx, "CREATE TABLE test(a INT, b TEXT)")
|
||||
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
@@ -61,8 +62,9 @@ func TestParserUpdate(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = q.Prepare(&query.Context{
|
||||
Ctx: context.Background(),
|
||||
DB: db,
|
||||
Ctx: context.Background(),
|
||||
DB: db,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@@ -142,6 +142,7 @@ type DiscardOperator struct {
|
||||
func Discard() *DiscardOperator {
|
||||
return &DiscardOperator{}
|
||||
}
|
||||
|
||||
func (it *DiscardOperator) Clone() Operator {
|
||||
return &DiscardOperator{
|
||||
BaseOperator: it.BaseOperator.Clone(),
|
||||
|
@@ -60,12 +60,28 @@ func NewTestDB(t testing.TB) *database.Database {
|
||||
return db
|
||||
}
|
||||
|
||||
func NewTestConn(t testing.TB, db *database.Database) *database.Connection {
|
||||
t.Helper()
|
||||
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
conn.Close()
|
||||
})
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func NewTestTx(t testing.TB) (*database.Database, *database.Transaction, func()) {
|
||||
t.Helper()
|
||||
|
||||
db := NewTestDB(t)
|
||||
conn := NewTestConn(t, db)
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return db, tx, func() {
|
||||
@@ -91,7 +107,7 @@ func Query(db *database.Database, tx *database.Transaction, q string, params ...
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := &query.Context{Ctx: context.Background(), DB: db, Tx: tx, Params: params}
|
||||
ctx := &query.Context{Ctx: context.Background(), DB: db, Conn: tx.Connection(), Params: params}
|
||||
err = pq.Prepare(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -111,7 +111,13 @@ func TestSQL(t *testing.T) {
|
||||
|
||||
if test.Fails {
|
||||
exec := func() error {
|
||||
res, err := db.Query(test.Expr)
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
res, err := conn.Query(test.Expr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -131,7 +137,11 @@ func TestSQL(t *testing.T) {
|
||||
require.Errorf(t, err, "\nSource:%s:%d expected\n%s\nto raise an error but got none", absPath, test.Line, test.Expr)
|
||||
}
|
||||
} else {
|
||||
res, err := db.Query(test.Expr)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err, "Source: %s:%d", absPath, test.Line)
|
||||
defer conn.Close()
|
||||
|
||||
res, err := conn.Query(test.Expr)
|
||||
require.NoError(t, err, "Source: %s:%d", absPath, test.Line)
|
||||
defer res.Close()
|
||||
|
||||
|
Reference in New Issue
Block a user