fix prepared statements

This commit is contained in:
Asdine El Hrychy
2025-08-31 17:45:17 +08:00
parent 40c1fcbbe1
commit 7f32a3b9c6
40 changed files with 541 additions and 665 deletions

View File

@@ -31,18 +31,11 @@ func ExecSQL(ctx context.Context, db *sql.DB, r io.Reader, w io.Writer) error {
var stmtWithOutputCount int
return parser.NewParser(r).Parse(func(s statement.Statement) error {
qq := query.New(s)
qctx := query.Context{
res, err := query.New(s).Run(&query.Context{
Ctx: ctx,
DB: conn.DB(),
Conn: conn.Conn(),
}
err := qq.Prepare(&qctx)
if err != nil {
return err
}
res, err := qq.Run(&qctx)
})
if err != nil {
return err
}

View File

@@ -19,12 +19,18 @@ func TestExecSQL(t *testing.T) {
err = ExecSQL(t.Context(), db, strings.NewReader(`
CREATE TABLE test(a INT, b TEXT);
CREATE INDEX idx_a ON test (a);
BEGIN;
INSERT INTO test (a, b) VALUES (10, 'aa'), (20, 'bb'), (30, 'cc');
ROLLBACK;
BEGIN;
INSERT INTO test (a, b) VALUES (1, 'a'), (2, 'b'), (3, 'c');
SELECT * FROM test;
COMMIT;
SELECT b, a FROM test;
`), &got)
require.NoError(t, err)
require.Equal(t, "a|b\n1|\"a\"\n2|\"b\"\n3|\"c\"\n", got.String())
require.Equal(t, "a|b\n1|\"a\"\n2|\"b\"\n3|\"c\"\n\nb|a\n\"a\"|1\n\"b\"|2\n\"c\"|3\n", got.String())
var res struct {
A int

View File

@@ -70,12 +70,12 @@ func ListIndexes(ctx context.Context, db *sql.DB, tableName string) ([]string, e
return nil, err
}
q, err := parser.ParseQuery(query)
statements, err := parser.ParseQuery(query)
if err != nil {
return nil, err
}
listName = append(listName, q.Statements[0].(*statement.CreateIndexStmt).Info.IndexName)
listName = append(listName, statements[0].(*statement.CreateIndexStmt).Info.IndexName)
}
if err := rows.Err(); err != nil {
return nil, err

View File

@@ -20,7 +20,13 @@ import (
func updateCatalog(t testing.TB, db *database.Database, fn func(tx *database.Transaction, catalog *database.CatalogWriter) error) {
t.Helper()
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
defer tx.Rollback()

View File

@@ -95,7 +95,7 @@ func Open(path string, opts *Options) (*Database, error) {
return nil, err
}
tx, err := db.Begin(true)
tx, err := db.begin(true)
if err != nil {
return nil, err
}
@@ -183,7 +183,7 @@ func (db *Database) Connect() (*Connection, error) {
// Begin starts a new transaction with default options.
// The returned transaction must be closed either by calling Rollback or Commit.
func (db *Database) Begin(writable bool) (*Transaction, error) {
func (db *Database) begin(writable bool) (*Transaction, error) {
if db.closeContext.Err() != nil {
return nil, errors.New("database is closed")
}

View File

@@ -157,7 +157,13 @@ func TestSequence(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
db := testutil.NewTestDB(t)
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
defer tx.Rollback()
@@ -193,7 +199,13 @@ func TestSequence(t *testing.T) {
t.Run("default cache", func(t *testing.T) {
db := testutil.NewTestDB(t)
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
defer tx.Rollback()
@@ -230,7 +242,13 @@ func TestSequence(t *testing.T) {
t.Run("cache", func(t *testing.T) {
db := testutil.NewTestDB(t)
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
defer tx.Rollback()
@@ -276,7 +294,13 @@ func TestSequence(t *testing.T) {
t.Run("cache desc", func(t *testing.T) {
db := testutil.NewTestDB(t)
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
defer tx.Rollback()
@@ -323,7 +347,13 @@ func TestSequence(t *testing.T) {
t.Run("read-only", func(t *testing.T) {
db := testutil.NewTestDB(t)
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
err = tx.CatalogWriter().CreateSequence(tx, &database.SequenceInfo{
@@ -339,7 +369,9 @@ func TestSequence(t *testing.T) {
require.NoError(t, err)
// open a read-only tx
tx, err = db.Begin(false)
tx, err = conn.BeginTx(&database.TxOptions{
ReadOnly: true,
})
require.NoError(t, err)
defer tx.Rollback()
@@ -353,7 +385,13 @@ func TestSequence(t *testing.T) {
t.Run("release", func(t *testing.T) {
db := testutil.NewTestDB(t)
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
defer tx.Rollback()

View File

@@ -21,7 +21,13 @@ var errDontCommit = errors.New("don't commit please")
func update(t testing.TB, db *database.Database, fn func(tx *database.Transaction) error) {
t.Helper()
tx, err := db.Begin(true)
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
tx, err := conn.BeginTx(&database.TxOptions{
ReadOnly: false,
})
require.NoError(t, err)
defer tx.Rollback()
@@ -63,7 +69,7 @@ func createTable(t testing.TB, tx *database.Transaction, info database.TableInfo
stmt := statement.CreateTableStmt{Info: info}
res, err := stmt.Run(&statement.Context{
Tx: tx,
Conn: tx.Connection(),
})
require.NoError(t, err)
res.Close()
@@ -80,7 +86,7 @@ func createTableIfNotExists(t testing.TB, tx *database.Transaction, info databas
stmt := statement.CreateTableStmt{Info: info, IfNotExists: true}
res, err := stmt.Run(&statement.Context{
Tx: tx,
Conn: tx.Connection(),
})
require.NoError(t, err)
res.Close()

View File

@@ -13,13 +13,11 @@ import (
// Results are returned as streams.
type Query struct {
Statements []statement.Statement
tx *database.Transaction
autoCommit bool
}
// New creates a new query with the given statements.
func New(statements ...statement.Statement) Query {
return Query{Statements: statements}
func New(statements ...statement.Statement) *Query {
return &Query{Statements: statements}
}
type Context struct {
@@ -29,133 +27,101 @@ type Context struct {
Params []environment.Param
}
func (c *Context) GetTx() *database.Transaction {
return c.Conn.GetTx()
}
// Prepare the statements by calling their Prepare methods.
// It stops at the first statement that doesn't implement the statement.Preparer interface.
func (q *Query) Prepare(context *Context) error {
// Run executes all the statements in their own transaction and returns the last result.
func (q *Query) Run(c *Context) (*statement.Result, error) {
var res *statement.Result
var err error
ctx := c.Ctx
if ctx == nil {
ctx = context.Background()
}
sctx := statement.Context{
DB: c.DB,
Conn: c.Conn,
Params: c.Params,
}
var tx *database.Transaction
ctx := context.Ctx
for i, stmt := range q.Statements {
if ctx != nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err := ctx.Err(); err != nil {
return nil, err
}
if tx == nil {
tx = context.GetTx()
if tx == nil {
tx, err = context.Conn.BeginTx(&database.TxOptions{
ReadOnly: true,
})
if err != nil {
return err
}
defer tx.Rollback()
}
}
sctx := &statement.Context{
DB: context.DB,
Conn: context.Conn,
Tx: tx,
}
err = stmt.Bind(sctx)
if err != nil {
return err
}
p, ok := stmt.(statement.Preparer)
if !ok {
continue
}
stmt, err := p.Prepare(sctx)
if err != nil {
return err
}
q.Statements[i] = stmt
}
return nil
}
// Run executes all the statements in their own transaction and returns the last result.
func (q *Query) Run(context *Context) (*statement.Result, error) {
var res statement.Result
var err error
q.tx = context.GetTx()
if q.tx == nil {
q.autoCommit = true
}
ctx := context.Ctx
for i, stmt := range q.Statements {
if ctx != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// reinitialize the result
res = statement.Result{}
res = nil
if qa, ok := stmt.(queryAlterer); ok {
err = qa.alterQuery(context.Conn, q)
if err != nil {
if tx := context.GetTx(); tx != nil {
_ = tx.Rollback()
}
return nil, err
}
continue
// handles transactions
isReadOnly := false
if ro, ok := stmt.(statement.ReadOnly); ok {
isReadOnly = ro.IsReadOnly()
}
if q.tx == nil {
q.tx, err = context.Conn.BeginTx(&database.TxOptions{
ReadOnly: stmt.IsReadOnly(),
needsTx := true
if ntx, ok := stmt.(statement.Transactional); ok {
needsTx = ntx.NeedsTransaction()
}
var autoCommit bool
if c.Conn.GetTx() == nil && needsTx {
autoCommit = true
tx, err = c.Conn.BeginTx(&database.TxOptions{
ReadOnly: isReadOnly,
})
if err != nil {
return nil, err
}
}
res, err = stmt.Run(&statement.Context{
DB: context.DB,
Conn: context.Conn,
Tx: q.tx,
Params: context.Params,
})
// bind
if b, ok := stmt.(statement.Bindable); ok {
err = b.Bind(&sctx)
if err != nil {
if tx != nil {
_ = tx.Rollback()
}
return nil, err
}
}
// prepare
if prep, ok := stmt.(statement.Preparer); ok {
stmt, err = prep.Prepare(&sctx)
if err != nil {
if tx != nil {
_ = tx.Rollback()
}
return nil, err
}
}
// run
res, err = stmt.Run(&sctx)
if err != nil {
if q.autoCommit {
_ = q.tx.Rollback()
if tx != nil {
_ = tx.Rollback()
}
return nil, err
}
if res == nil {
res = &statement.Result{}
}
// if there are still statements to be executed,
// and the current statement is not read-only,
// iterate over the result.
if !stmt.IsReadOnly() && i+1 < len(q.Statements) {
err = res.Skip()
if res != nil && !isReadOnly && i+1 < len(q.Statements) {
err = res.Skip(ctx)
if err != nil {
if q.autoCommit {
_ = q.tx.Rollback()
if tx != nil {
_ = tx.Rollback()
}
return nil, err
@@ -164,31 +130,27 @@ func (q *Query) Run(context *Context) (*statement.Result, error) {
// it there is an opened transaction but there are still statements
// to be executed, close the current transaction.
if q.tx != nil && q.autoCommit && i+1 < len(q.Statements) {
if q.tx.Writable {
err := q.tx.Commit()
if autoCommit && i+1 < len(q.Statements) {
if tx.Writable {
err := tx.Commit()
if err != nil {
return nil, err
}
} else {
err := q.tx.Rollback()
err := tx.Rollback()
if err != nil {
return nil, err
}
}
q.tx = nil
tx = nil
}
}
if q.autoCommit {
if tx != nil {
// the returned result will now own the transaction.
// its Close method is expected to be called.
res.Tx = q.tx
res.Tx = tx
}
return &res, nil
}
type queryAlterer interface {
alterQuery(conn *database.Connection, q *Query) error
return res, nil
}

View File

@@ -18,34 +18,23 @@ type AlterTableRenameStmt struct {
NewTableName string
}
func (stmt *AlterTableRenameStmt) Bind(ctx *Context) error {
return nil
}
// IsReadOnly always returns false. It implements the Statement interface.
func (stmt *AlterTableRenameStmt) IsReadOnly() bool {
return false
}
// Run runs the ALTER TABLE statement in the given transaction.
// It implements the Statement interface.
func (stmt *AlterTableRenameStmt) Run(ctx *Context) (Result, error) {
var res Result
func (stmt *AlterTableRenameStmt) Run(ctx *Context) (*Result, error) {
if stmt.TableName == "" {
return res, errors.New("missing table name")
return nil, errors.New("missing table name")
}
if stmt.NewTableName == "" {
return res, errors.New("missing new table name")
return nil, errors.New("missing new table name")
}
if stmt.TableName == stmt.NewTableName {
return res, errs.AlreadyExistsError{Name: stmt.NewTableName}
return nil, errs.AlreadyExistsError{Name: stmt.NewTableName}
}
err := ctx.Tx.CatalogWriter().RenameTable(ctx.Tx, stmt.TableName, stmt.NewTableName)
return res, err
err := ctx.Conn.GetTx().CatalogWriter().RenameTable(ctx.Conn.GetTx(), stmt.TableName, stmt.NewTableName)
return nil, err
}
type AlterTableAddColumnStmt struct {
@@ -54,41 +43,32 @@ type AlterTableAddColumnStmt struct {
TableConstraints database.TableConstraints
}
// IsReadOnly always returns false. It implements the Statement interface.
func (stmt *AlterTableAddColumnStmt) IsReadOnly() bool {
return false
}
func (stmt *AlterTableAddColumnStmt) Bind(ctx *Context) error {
return nil
}
// Run runs the ALTER TABLE ADD COLUMN statement in the given transaction.
// It implements the Statement interface.
// The statement rebuilds the table.
func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (*Result, error) {
var err error
// get the table before adding the column constraint
// and assign the table to the table.Scan operator
// so that it can decode the records properly
scan := table.Scan(stmt.TableName)
scan.Table, err = ctx.Tx.Catalog.GetTable(ctx.Tx, stmt.TableName)
scan.Table, err = ctx.Conn.GetTx().Catalog.GetTable(ctx.Conn.GetTx(), stmt.TableName)
if err != nil {
return Result{}, errors.Wrap(err, "failed to get table")
return nil, errors.Wrap(err, "failed to get table")
}
// get the current list of indexes
indexNames := ctx.Tx.Catalog.ListIndexes(stmt.TableName)
indexNames := ctx.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
// add the column constraint to the table
err = ctx.Tx.CatalogWriter().AddColumnConstraint(
ctx.Tx,
err = ctx.Conn.GetTx().CatalogWriter().AddColumnConstraint(
ctx.Conn.GetTx(),
stmt.TableName,
stmt.ColumnConstraint,
stmt.TableConstraints)
if err != nil {
return Result{}, err
return nil, err
}
// create a unique index for every unique constraint
@@ -96,7 +76,7 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
var newIdxs []*database.IndexInfo
for _, tc := range stmt.TableConstraints {
if tc.Unique {
idx, err := ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{
idx, err := ctx.Conn.GetTx().CatalogWriter().CreateIndex(ctx.Conn.GetTx(), &database.IndexInfo{
Columns: tc.Columns,
Unique: true,
Owner: database.Owner{
@@ -105,7 +85,7 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
},
})
if err != nil {
return Result{}, err
return nil, err
}
newIdxs = append(newIdxs, idx)
@@ -141,11 +121,11 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
s = s.Pipe(table.Insert(stmt.TableName))
// insert the record into the all the indexes
indexNames = ctx.Tx.Catalog.ListIndexes(stmt.TableName)
indexNames = ctx.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
for _, indexName := range indexNames {
info, err := ctx.Tx.Catalog.GetIndexInfo(indexName)
info, err := ctx.Conn.GetTx().Catalog.GetIndexInfo(indexName)
if err != nil {
return Result{}, err
return nil, err
}
if info.Unique {
s = s.Pipe(index.Validate(indexName))
@@ -176,7 +156,7 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
s = s.Pipe(stream.Discard())
// do NOT optimize the stream
return Result{
return &Result{
Result: &StreamStmtResult{
Stream: s,
Context: ctx,

View File

@@ -20,20 +20,9 @@ type CreateTableStmt struct {
Info database.TableInfo
}
// IsReadOnly always returns false. It implements the Statement interface.
func (stmt *CreateTableStmt) IsReadOnly() bool {
return false
}
func (stmt *CreateTableStmt) Bind(ctx *Context) error {
return nil
}
// Run runs the Create table statement in the given transaction.
// It implements the Statement interface.
func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) {
var res Result
func (stmt *CreateTableStmt) Run(ctx *Context) (*Result, error) {
// if there is no primary key, create a rowid sequence
if stmt.Info.PrimaryKey == nil {
seq := database.SequenceInfo{
@@ -45,25 +34,25 @@ func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) {
TableName: stmt.Info.TableName,
},
}
err := ctx.Tx.CatalogWriter().CreateSequence(ctx.Tx, &seq)
err := ctx.Conn.GetTx().CatalogWriter().CreateSequence(ctx.Conn.GetTx(), &seq)
if err != nil {
return res, err
return nil, err
}
stmt.Info.RowidSequenceName = seq.Name
}
err := ctx.Tx.CatalogWriter().CreateTable(ctx.Tx, stmt.Info.TableName, &stmt.Info)
err := ctx.Conn.GetTx().CatalogWriter().CreateTable(ctx.Conn.GetTx(), stmt.Info.TableName, &stmt.Info)
if stmt.IfNotExists {
if errs.IsAlreadyExistsError(err) {
return res, nil
return nil, nil
}
}
// create a unique index for every unique constraint
for _, tc := range stmt.Info.TableConstraints {
if tc.Unique {
_, err = ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{
_, err = ctx.Conn.GetTx().CatalogWriter().CreateIndex(ctx.Conn.GetTx(), &database.IndexInfo{
Columns: tc.Columns,
Unique: true,
Owner: database.Owner{
@@ -73,12 +62,12 @@ func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) {
KeySortOrder: tc.SortOrder,
})
if err != nil {
return res, err
return nil, err
}
}
}
return res, err
return nil, err
}
// CreateIndexStmt represents a parsed CREATE INDEX statement.
@@ -87,28 +76,17 @@ type CreateIndexStmt struct {
Info database.IndexInfo
}
// IsReadOnly always returns false. It implements the Statement interface.
func (stmt *CreateIndexStmt) IsReadOnly() bool {
return false
}
func (stmt *CreateIndexStmt) Bind(ctx *Context) error {
return nil
}
// Run runs the Create index statement in the given transaction.
// It implements the Statement interface.
func (stmt *CreateIndexStmt) Run(ctx *Context) (Result, error) {
var res Result
_, err := ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &stmt.Info)
func (stmt *CreateIndexStmt) Run(ctx *Context) (*Result, error) {
_, err := ctx.Conn.GetTx().CatalogWriter().CreateIndex(ctx.Conn.GetTx(), &stmt.Info)
if stmt.IfNotExists {
if errs.IsAlreadyExistsError(err) {
return res, nil
return nil, nil
}
}
if err != nil {
return res, err
return nil, err
}
s := stream.New(table.Scan(stmt.Info.Owner.TableName)).
@@ -116,8 +94,7 @@ func (stmt *CreateIndexStmt) Run(ctx *Context) (Result, error) {
Pipe(stream.Discard())
ss := PreparedStreamStmt{
Stream: s,
ReadOnly: false,
Stream: s,
}
return ss.Run(ctx)
@@ -129,25 +106,14 @@ type CreateSequenceStmt struct {
Info database.SequenceInfo
}
// IsReadOnly always returns false. It implements the Statement interface.
func (stmt *CreateSequenceStmt) IsReadOnly() bool {
return false
}
func (stmt *CreateSequenceStmt) Bind(ctx *Context) error {
return nil
}
// Run the statement in the given transaction.
// It implements the Statement interface.
func (stmt *CreateSequenceStmt) Run(ctx *Context) (Result, error) {
var res Result
err := ctx.Tx.CatalogWriter().CreateSequence(ctx.Tx, &stmt.Info)
func (stmt *CreateSequenceStmt) Run(ctx *Context) (*Result, error) {
err := ctx.Conn.GetTx().CatalogWriter().CreateSequence(ctx.Conn.GetTx(), &stmt.Info)
if stmt.IfNotExists {
if errs.IsAlreadyExistsError(err) {
return res, nil
return nil, nil
}
}
return res, err
return nil, err
}

View File

@@ -13,7 +13,7 @@ var _ Statement = (*DeleteStmt)(nil)
// DeleteConfig holds DELETE configuration.
type DeleteStmt struct {
basePreparedStatement
PreparedStreamStmt
TableName string
WhereExpr expr.Expr
@@ -23,17 +23,6 @@ type DeleteStmt struct {
OrderByDirection scanner.Token
}
func NewDeleteStatement() *DeleteStmt {
var p DeleteStmt
p.basePreparedStatement = basePreparedStatement{
Preparer: &p,
ReadOnly: false,
}
return &p
}
func (stmt *DeleteStmt) Bind(ctx *Context) error {
err := BindExpr(ctx, stmt.TableName, stmt.WhereExpr)
if err != nil {
@@ -81,7 +70,7 @@ func (stmt *DeleteStmt) Prepare(c *Context) (Statement, error) {
s = s.Pipe(rows.Take(stmt.LimitExpr))
}
indexNames := c.Tx.Catalog.ListIndexes(stmt.TableName)
indexNames := c.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
for _, indexName := range indexNames {
s = s.Pipe(index.Delete(indexName))
}
@@ -90,10 +79,6 @@ func (stmt *DeleteStmt) Prepare(c *Context) (Statement, error) {
s = s.Pipe(stream.Discard())
st := StreamStmt{
Stream: s,
ReadOnly: false,
}
return st.Prepare(c)
stmt.PreparedStreamStmt.Stream = s
return stmt, nil
}

View File

@@ -28,36 +28,34 @@ func (stmt *DropTableStmt) Bind(ctx *Context) error {
// Run runs the DropTable statement in the given transaction.
// It implements the Statement interface.
func (stmt *DropTableStmt) Run(ctx *Context) (Result, error) {
var res Result
func (stmt *DropTableStmt) Run(ctx *Context) (*Result, error) {
if stmt.TableName == "" {
return res, errors.New("missing table name")
return nil, errors.New("missing table name")
}
tb, err := ctx.Tx.Catalog.GetTable(ctx.Tx, stmt.TableName)
tb, err := ctx.Conn.GetTx().Catalog.GetTable(ctx.Conn.GetTx(), stmt.TableName)
if err != nil {
if errs.IsNotFoundError(err) && stmt.IfExists {
err = nil
}
return res, err
return nil, err
}
err = ctx.Tx.CatalogWriter().DropTable(ctx.Tx, stmt.TableName)
err = ctx.Conn.GetTx().CatalogWriter().DropTable(ctx.Conn.GetTx(), stmt.TableName)
if err != nil {
return res, err
return nil, err
}
// if there is no primary key, drop the rowid sequence
if tb.Info.PrimaryKey == nil {
err = ctx.Tx.CatalogWriter().DropSequence(ctx.Tx, tb.Info.RowidSequenceName)
err = ctx.Conn.GetTx().CatalogWriter().DropSequence(ctx.Conn.GetTx(), tb.Info.RowidSequenceName)
if err != nil {
return res, err
return nil, err
}
}
return res, err
return nil, err
}
// DropIndexStmt is a DSL that allows creating a DROP INDEX query.
@@ -77,19 +75,17 @@ func (stmt *DropIndexStmt) Bind(ctx *Context) error {
// Run runs the DropIndex statement in the given transaction.
// It implements the Statement interface.
func (stmt *DropIndexStmt) Run(ctx *Context) (Result, error) {
var res Result
func (stmt *DropIndexStmt) Run(ctx *Context) (*Result, error) {
if stmt.IndexName == "" {
return res, errors.New("missing index name")
return nil, errors.New("missing index name")
}
err := ctx.Tx.CatalogWriter().DropIndex(ctx.Tx, stmt.IndexName)
err := ctx.Conn.GetTx().CatalogWriter().DropIndex(ctx.Conn.GetTx(), stmt.IndexName)
if errs.IsNotFoundError(err) && stmt.IfExists {
err = nil
}
return res, err
return nil, err
}
// DropSequenceStmt is a DSL that allows creating a DROP INDEX query.
@@ -109,29 +105,27 @@ func (stmt *DropSequenceStmt) Bind(ctx *Context) error {
// Run runs the DropSequence statement in the given transaction.
// It implements the Statement interface.
func (stmt *DropSequenceStmt) Run(ctx *Context) (Result, error) {
var res Result
func (stmt *DropSequenceStmt) Run(ctx *Context) (*Result, error) {
if stmt.SequenceName == "" {
return res, errors.New("missing index name")
return nil, errors.New("missing index name")
}
seq, err := ctx.Tx.Catalog.GetSequence(stmt.SequenceName)
seq, err := ctx.Conn.GetTx().Catalog.GetSequence(stmt.SequenceName)
if err != nil {
if errs.IsNotFoundError(err) && stmt.IfExists {
err = nil
}
return res, err
return nil, err
}
if seq.Info.Owner.TableName != "" {
return res, fmt.Errorf("cannot drop sequence %s because constraint of table %s requires it", seq.Info.Name, seq.Info.Owner.TableName)
return nil, fmt.Errorf("cannot drop sequence %s because constraint of table %s requires it", seq.Info.Name, seq.Info.Owner.TableName)
}
err = ctx.Tx.CatalogWriter().DropSequence(ctx.Tx, stmt.SequenceName)
err = ctx.Conn.GetTx().CatalogWriter().DropSequence(ctx.Conn.GetTx(), stmt.SequenceName)
if errs.IsNotFoundError(err) && stmt.IfExists {
err = nil
}
return res, err
return nil, err
}

View File

@@ -19,7 +19,7 @@ type ExplainStmt struct {
}
func (stmt *ExplainStmt) Bind(ctx *Context) error {
if s, ok := stmt.Statement.(Statement); ok {
if s, ok := stmt.Statement.(Bindable); ok {
return s.Bind(ctx)
}
@@ -30,26 +30,36 @@ func (stmt *ExplainStmt) Bind(ctx *Context) error {
// If the statement is a stream, Optimize will be called prior to
// displaying all the operations.
// Explain currently only works on SELECT, UPDATE, INSERT and DELETE statements.
func (stmt *ExplainStmt) Run(ctx *Context) (Result, error) {
func (stmt *ExplainStmt) Run(ctx *Context) (*Result, error) {
st, err := stmt.Statement.Prepare(ctx)
if err != nil {
return Result{}, err
return nil, err
}
s, ok := st.(*PreparedStreamStmt)
if !ok {
return Result{}, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements")
var s *stream.Stream
switch stmt := st.(type) {
case *InsertStmt:
s = stmt.Stream
case *SelectStmt:
s = stmt.Stream
case *UpdateStmt:
s = stmt.Stream
case *DeleteStmt:
s = stmt.Stream
default:
return nil, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements")
}
// Optimize the stream.
s.Stream, err = planner.Optimize(s.Stream, ctx.Tx.Catalog, ctx.Params)
s, err = planner.Optimize(s, ctx.Conn.GetTx().Catalog, ctx.Params)
if err != nil {
return Result{}, err
return nil, err
}
var plan string
if s.Stream != nil {
plan = s.Stream.String()
if s != nil {
plan = s.String()
} else {
plan = "<no exec>"
}
@@ -62,7 +72,6 @@ func (stmt *ExplainStmt) Run(ctx *Context) (Result, error) {
Expr: expr.LiteralValue{Value: types.NewTextValue(plan)},
}),
},
ReadOnly: true,
}
return newStatement.Run(ctx)
}

View File

@@ -15,7 +15,7 @@ var _ Statement = (*InsertStmt)(nil)
// InsertStmt holds INSERT configuration.
type InsertStmt struct {
basePreparedStatement
PreparedStreamStmt
TableName string
Values []expr.Expr
@@ -25,17 +25,6 @@ type InsertStmt struct {
OnConflict database.OnConflictAction
}
func NewInsertStatement() *InsertStmt {
var p InsertStmt
p.basePreparedStatement = basePreparedStatement{
Preparer: &p,
ReadOnly: false,
}
return &p
}
func (stmt *InsertStmt) Bind(ctx *Context) error {
for i := range stmt.Values {
err := BindExpr(ctx, stmt.TableName, stmt.Values[i])
@@ -45,7 +34,7 @@ func (stmt *InsertStmt) Bind(ctx *Context) error {
}
if stmt.SelectStmt != nil {
if s, ok := stmt.SelectStmt.(Statement); ok {
if s, ok := stmt.SelectStmt.(Bindable); ok {
err := s.Bind(ctx)
if err != nil {
return err
@@ -68,7 +57,7 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
var columns []string
if stmt.Values != nil {
ti, err := c.Tx.Catalog.GetTableInfo(stmt.TableName)
ti, err := c.Conn.GetTx().Catalog.GetTableInfo(stmt.TableName)
if err != nil {
return nil, err
}
@@ -129,7 +118,7 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
return nil, err
}
s = selectStream.(*PreparedStreamStmt).Stream
s = selectStream.(*SelectStmt).Stream
// ensure we are not reading and writing to the same table.
// TODO(asdine): if same table, write content to a temp table.
@@ -157,9 +146,9 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
}
// check unique constraints
indexNames := c.Tx.Catalog.ListIndexes(stmt.TableName)
indexNames := c.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
for _, indexName := range indexNames {
info, err := c.Tx.Catalog.GetIndexInfo(indexName)
info, err := c.Conn.GetTx().Catalog.GetIndexInfo(indexName)
if err != nil {
return nil, err
}
@@ -189,10 +178,6 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
s = s.Pipe(stream.Discard())
}
st := StreamStmt{
Stream: s,
ReadOnly: false,
}
return st.Prepare(c)
stmt.PreparedStreamStmt.Stream = s
return stmt, nil
}

View File

@@ -12,34 +12,19 @@ var _ Statement = (*ReIndexStmt)(nil)
// ReIndexStmt is a DSL that allows creating a full REINDEX statement.
type ReIndexStmt struct {
basePreparedStatement
PreparedStreamStmt
TableOrIndexName string
}
func NewReIndexStatement() *ReIndexStmt {
var p ReIndexStmt
p.basePreparedStatement = basePreparedStatement{
Preparer: &p,
ReadOnly: false,
}
return &p
}
func (stmt *ReIndexStmt) Bind(ctx *Context) error {
return nil
}
// Prepare implements the Preparer interface.
func (stmt *ReIndexStmt) Prepare(ctx *Context) (Statement, error) {
var indexNames []string
if stmt.TableOrIndexName == "" {
indexNames = ctx.Tx.Catalog.Cache.ListObjects(database.RelationIndexType)
} else if _, err := ctx.Tx.Catalog.GetTable(ctx.Tx, stmt.TableOrIndexName); err == nil {
indexNames = ctx.Tx.Catalog.ListIndexes(stmt.TableOrIndexName)
indexNames = ctx.Conn.GetTx().Catalog.Cache.ListObjects(database.RelationIndexType)
} else if _, err := ctx.Conn.GetTx().Catalog.GetTable(ctx.Conn.GetTx(), stmt.TableOrIndexName); err == nil {
indexNames = ctx.Conn.GetTx().Catalog.ListIndexes(stmt.TableOrIndexName)
} else if !errs.IsNotFoundError(err) {
return nil, err
} else {
@@ -49,11 +34,11 @@ func (stmt *ReIndexStmt) Prepare(ctx *Context) (Statement, error) {
var streams []*stream.Stream
for _, indexName := range indexNames {
idx, err := ctx.Tx.Catalog.GetIndex(ctx.Tx, indexName)
idx, err := ctx.Conn.GetTx().Catalog.GetIndex(ctx.Conn.GetTx(), indexName)
if err != nil {
return nil, err
}
info, err := ctx.Tx.Catalog.GetIndexInfo(indexName)
info, err := ctx.Conn.GetTx().Catalog.GetIndexInfo(indexName)
if err != nil {
return nil, err
}
@@ -67,10 +52,8 @@ func (stmt *ReIndexStmt) Prepare(ctx *Context) (Statement, error) {
streams = append(streams, s)
}
st := StreamStmt{
Stream: stream.New(stream.Concat(streams...)).Pipe(stream.Discard()),
ReadOnly: false,
}
s := stream.New(stream.Concat(streams...)).Pipe(stream.Discard())
return st.Prepare(ctx)
stmt.PreparedStreamStmt.Stream = s
return stmt, nil
}

View File

@@ -43,13 +43,33 @@ func (stmt *SelectCoreStmt) Bind(ctx *Context) error {
return nil
}
func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) {
isReadOnly := true
func (stmt *SelectCoreStmt) IsReadOnly() bool {
var isReadOnly = true
// SELECT is read-only most of the time, unless it's using some expressions
// that require write access and that are allowed to be run, such as nextval
for _, e := range stmt.ProjectionExprs {
expr.Walk(e, func(e expr.Expr) bool {
switch e.(type) {
case *expr.NamedExpr:
return true
case *functions.NextVal:
isReadOnly = false
return false
default:
return true
}
})
}
return isReadOnly
}
func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*stream.Stream, error) {
var s *stream.Stream
if stmt.TableName != "" {
_, err := ctx.Tx.Catalog.GetTableInfo(stmt.TableName)
_, err := ctx.Conn.GetTx().Catalog.GetTableInfo(stmt.TableName)
if err != nil {
return nil, err
}
@@ -148,35 +168,16 @@ func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) {
}
s = s.Pipe(rows.Project(stmt.ProjectionExprs...))
// SELECT is read-only most of the time, unless it's using some expressions
// that require write access and that are allowed to be run, such as nextval
for _, e := range stmt.ProjectionExprs {
expr.Walk(e, func(e expr.Expr) bool {
switch e.(type) {
case *expr.NamedExpr:
return true
case *functions.NextVal:
isReadOnly = false
return false
default:
return true
}
})
}
if stmt.Distinct {
s = stream.New(stream.Union(s))
}
return &StreamStmt{
Stream: s,
ReadOnly: isReadOnly,
}, nil
return s, nil
}
// SelectStmt holds SELECT configuration.
type SelectStmt struct {
basePreparedStatement
PreparedStreamStmt
CompoundSelect []*SelectCoreStmt
CompoundOperators []scanner.Token
@@ -186,15 +187,13 @@ type SelectStmt struct {
LimitExpr expr.Expr
}
func NewSelectStatement() *SelectStmt {
var p SelectStmt
p.basePreparedStatement = basePreparedStatement{
Preparer: &p,
ReadOnly: true,
func (stmt *SelectStmt) IsReadOnly() bool {
for i := range stmt.CompoundSelect {
if !stmt.CompoundSelect[i].IsReadOnly() {
return false
}
}
return &p
return true
}
func (stmt *SelectStmt) Bind(ctx *Context) error {
@@ -230,7 +229,6 @@ func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) {
var prev scanner.Token
var coreStmts []*stream.Stream
var readOnly bool = true
for i, coreSelect := range stmt.CompoundSelect {
coreStmt, err := coreSelect.Prepare(ctx)
@@ -239,16 +237,11 @@ func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) {
}
if len(stmt.CompoundSelect) == 1 {
s = coreStmt.Stream
readOnly = coreStmt.ReadOnly
s = coreStmt
break
}
coreStmts = append(coreStmts, coreStmt.Stream)
if !coreStmt.ReadOnly {
readOnly = false
}
coreStmts = append(coreStmts, coreStmt)
var tok scanner.Token
if i < len(stmt.CompoundOperators) {
@@ -285,10 +278,6 @@ func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) {
s = s.Pipe(rows.Take(stmt.LimitExpr))
}
st := StreamStmt{
Stream: s,
ReadOnly: readOnly,
}
return st.Prepare(ctx)
stmt.PreparedStreamStmt.Stream = s
return stmt, nil
}

View File

@@ -1,6 +1,8 @@
package statement
import (
"context"
"github.com/chaisql/chai/internal/database"
"github.com/chaisql/chai/internal/environment"
"github.com/chaisql/chai/internal/expr"
@@ -9,33 +11,31 @@ import (
// A Statement represents a unique action that can be executed against the database.
type Statement interface {
Bind(*Context) error
Run(*Context) (Result, error)
Run(*Context) (*Result, error)
}
// Optional interface that allows a statement to specify if it is read-only.
// Defaults to false if not implemented.
type ReadOnly interface {
IsReadOnly() bool
}
type basePreparedStatement struct {
Preparer Preparer
ReadOnly bool
// Optional interface that allows a statement to specify if they need a transaction.
// Defaults to true if not implemented.
// If true, the engine will auto-commit.
type Transactional interface {
NeedsTransaction() bool
}
func (stmt *basePreparedStatement) IsReadOnly() bool {
return stmt.ReadOnly
}
func (stmt *basePreparedStatement) Run(ctx *Context) (Result, error) {
s, err := stmt.Preparer.Prepare(ctx)
if err != nil {
return Result{}, err
}
return s.Run(ctx)
// Optional interface that allows a statement to specify if they need to be bound to database
// objects.
type Bindable interface {
Bind(*Context) error
}
type Context struct {
DB *database.Database
Conn *database.Connection
Tx *database.Transaction
Params []environment.Param
}
@@ -90,7 +90,10 @@ func (r *Result) Iterate(fn func(r database.Row) error) (err error) {
// Skip iterates over the result and skips all rows.
// It is useful when you need the query to be executed
// but don't care about the results.
func (r *Result) Skip() (err error) {
func (r *Result) Skip(ctx context.Context) (err error) {
if r == nil {
return nil
}
if r.Result == nil {
return nil
}
@@ -107,6 +110,9 @@ func (r *Result) Skip() (err error) {
defer it.Close()
for it.Next() {
if err := ctx.Err(); err != nil {
return err
}
}
return it.Error()
@@ -122,7 +128,7 @@ func (r *Result) Columns() ([]string, error) {
return nil, nil
}
env := environment.New(stmt.Context.DB, stmt.Context.Tx, stmt.Context.Params, nil)
env := environment.New(stmt.Context.DB, stmt.Context.Conn.GetTx(), stmt.Context.Params, nil)
return stmt.Stream.Columns(env)
}
@@ -158,7 +164,7 @@ func BindExpr(ctx *Context, tableName string, e expr.Expr) (err error) {
var info *database.TableInfo
if tableName != "" {
info, err = ctx.Tx.Catalog.GetTableInfo(tableName)
info, err = ctx.Conn.GetTx().Catalog.GetTableInfo(tableName)
if err != nil {
return err
}

View File

@@ -15,33 +15,20 @@ type StreamStmt struct {
ReadOnly bool
}
// Prepare implements the Preparer interface.
func (s *StreamStmt) Prepare(ctx *Context) (Statement, error) {
return &PreparedStreamStmt{
Stream: s.Stream,
ReadOnly: s.ReadOnly,
}, nil
}
// PreparedStreamStmt is a PreparedStreamStmt using a Stream.
type PreparedStreamStmt struct {
Stream *stream.Stream
ReadOnly bool
}
func (s *PreparedStreamStmt) Bind(ctx *Context) error {
return nil
Stream *stream.Stream
}
// Run returns a result containing the stream. The stream will be executed by calling the Iterate method of
// the result.
func (s *PreparedStreamStmt) Run(ctx *Context) (Result, error) {
st, err := planner.Optimize(s.Stream.Clone(), ctx.Tx.Catalog, ctx.Params)
func (s *PreparedStreamStmt) Run(ctx *Context) (*Result, error) {
st, err := planner.Optimize(s.Stream.Clone(), ctx.Conn.GetTx().Catalog, ctx.Params)
if err != nil {
return Result{}, err
return nil, err
}
return Result{
return &Result{
Result: &StreamStmtResult{
Stream: st,
Context: ctx,
@@ -49,11 +36,6 @@ func (s *PreparedStreamStmt) Run(ctx *Context) (Result, error) {
}, nil
}
// IsReadOnly reports whether the stream will modify the database or only read it.
func (s *PreparedStreamStmt) IsReadOnly() bool {
return s.ReadOnly
}
func (s *PreparedStreamStmt) String() string {
return s.Stream.String()
}
@@ -65,7 +47,7 @@ type StreamStmtResult struct {
}
func (s *StreamStmtResult) Iterator() (database.Iterator, error) {
env := environment.New(s.Context.DB, s.Context.Tx, s.Context.Params, nil)
env := environment.New(s.Context.DB, s.Context.Conn.GetTx(), s.Context.Params, nil)
return s.Stream.Iterator(env)
}

View File

@@ -13,7 +13,7 @@ var _ Statement = (*UpdateStmt)(nil)
// UpdateConfig holds UPDATE configuration.
type UpdateStmt struct {
basePreparedStatement
PreparedStreamStmt
TableName string
@@ -25,17 +25,6 @@ type UpdateStmt struct {
WhereExpr expr.Expr
}
func NewUpdateStatement() *UpdateStmt {
var p UpdateStmt
p.basePreparedStatement = basePreparedStatement{
Preparer: &p,
ReadOnly: false,
}
return &p
}
type UpdateSetPair struct {
Column *expr.Column
E expr.Expr
@@ -64,7 +53,7 @@ func (stmt *UpdateStmt) Bind(ctx *Context) error {
// Prepare implements the Preparer interface.
func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
ti, err := c.Tx.Catalog.GetTableInfo(stmt.TableName)
ti, err := c.Conn.GetTx().Catalog.GetTableInfo(stmt.TableName)
if err != nil {
return nil, err
}
@@ -99,7 +88,7 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
// TODO(asdine): This removes ALL indexed fields for each row
// even if the update modified a single field. We should only
// update the indexed fields that were modified.
indexNames := c.Tx.Catalog.ListIndexes(stmt.TableName)
indexNames := c.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
for _, indexName := range indexNames {
s = s.Pipe(index.Delete(indexName))
}
@@ -114,7 +103,7 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
}
for _, indexName := range indexNames {
info, err := c.Tx.Catalog.GetIndexInfo(indexName)
info, err := c.Conn.GetTx().Catalog.GetIndexInfo(indexName)
if err != nil {
return nil, err
}
@@ -127,10 +116,6 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
s = s.Pipe(stream.Discard())
st := StreamStmt{
Stream: s,
ReadOnly: false,
}
return st.Prepare(c)
stmt.PreparedStreamStmt.Stream = s
return stmt, nil
}

View File

@@ -6,109 +6,58 @@ import (
"github.com/cockroachdb/errors"
)
var _ queryAlterer = BeginStmt{}
var _ queryAlterer = RollbackStmt{}
var _ queryAlterer = CommitStmt{}
// BeginStmt is a statement that creates a new transaction.
type BeginStmt struct {
Writable bool
}
func (stmt BeginStmt) Bind(ctx *statement.Context) error {
return nil
func (stmt BeginStmt) NeedsTransaction() bool {
return false
}
// Prepare implements the Preparer interface.
func (stmt BeginStmt) Prepare(*statement.Context) (statement.Statement, error) {
return stmt, nil
}
func (stmt BeginStmt) alterQuery(conn *database.Connection, q *Query) error {
if q.tx != nil {
return errors.New("cannot begin a transaction within a transaction")
func (stmt BeginStmt) Run(ctx *statement.Context) (*statement.Result, error) {
if ctx.Conn.GetTx() != nil {
return nil, errors.New("cannot begin a transaction within a transaction")
}
var err error
q.tx, err = conn.BeginTx(&database.TxOptions{
_, err := ctx.Conn.BeginTx(&database.TxOptions{
ReadOnly: !stmt.Writable,
})
q.autoCommit = false
return err
}
func (stmt BeginStmt) IsReadOnly() bool {
return !stmt.Writable
}
func (stmt BeginStmt) Run(ctx *statement.Context) (statement.Result, error) {
return statement.Result{}, errors.New("cannot begin a transaction within a transaction")
return nil, err
}
// RollbackStmt is a statement that rollbacks the current active transaction.
type RollbackStmt struct{}
func (stmt RollbackStmt) Bind(ctx *statement.Context) error {
return nil
}
// Prepare implements the Preparer interface.
func (stmt RollbackStmt) Prepare(*statement.Context) (statement.Statement, error) {
return stmt, nil
}
func (stmt RollbackStmt) alterQuery(conn *database.Connection, q *Query) error {
if q.tx == nil || q.autoCommit {
return errors.New("cannot rollback with no active transaction")
}
err := q.tx.Rollback()
if err != nil {
return err
}
q.tx = nil
q.autoCommit = true
return nil
}
func (stmt RollbackStmt) IsReadOnly() bool {
func (stmt RollbackStmt) NeedsTransaction() bool {
return false
}
func (stmt RollbackStmt) Run(ctx *statement.Context) (statement.Result, error) {
return statement.Result{}, errors.New("cannot rollback with no active transaction")
func (stmt RollbackStmt) Run(ctx *statement.Context) (*statement.Result, error) {
tx := ctx.Conn.GetTx()
if tx == nil {
return nil, errors.New("cannot rollback with no active transaction")
}
return nil, tx.Rollback()
}
// CommitStmt is a statement that commits the current active transaction.
type CommitStmt struct{}
func (stmt CommitStmt) Bind(ctx *statement.Context) error {
return nil
}
// Prepare implements the Preparer interface.
func (stmt CommitStmt) Prepare(*statement.Context) (statement.Statement, error) {
return stmt, nil
}
func (stmt CommitStmt) alterQuery(conn *database.Connection, q *Query) error {
if q.tx == nil || q.autoCommit {
return errors.New("cannot commit with no active transaction")
}
err := q.tx.Commit()
if err != nil {
return err
}
q.tx = nil
q.autoCommit = true
return nil
}
func (stmt CommitStmt) IsReadOnly() bool {
return false
}
func (stmt CommitStmt) Run(ctx *statement.Context) (statement.Result, error) {
return statement.Result{}, errors.New("cannot commit with no active transaction")
func (stmt CommitStmt) NeedsTransaction() bool {
return false
}
func (stmt CommitStmt) Run(ctx *statement.Context) (*statement.Result, error) {
tx := ctx.Conn.GetTx()
if tx == nil {
return nil, errors.New("cannot commit with no active transaction")
}
return nil, tx.Commit()
}

View File

@@ -17,8 +17,10 @@ import (
)
var (
_ driver.Driver = (*Driver)(nil)
_ driver.DriverContext = (*Driver)(nil)
_ driver.Driver = (*Driver)(nil)
_ driver.DriverContext = (*Driver)(nil)
_ driver.QueryerContext = (*Conn)(nil)
_ driver.ExecerContext = (*Conn)(nil)
)
// Driver is a driver.Driver that can open a new connection to a Chai database.
@@ -98,24 +100,100 @@ func (c *Conn) Prepare(q string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), q)
}
// PrepareContext returns a prepared statement, bound to this connection.
func (c *Conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
pq, err := parser.ParseQuery(q)
func (c *Conn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
statements, err := parser.ParseQuery(q)
if err != nil {
return nil, err
}
err = pq.Prepare(&query.Context{
Ctx: ctx,
DB: c.db,
Conn: c.conn,
res, err := query.New(statements...).Run(&query.Context{
Ctx: ctx,
DB: c.DB(),
Conn: c.conn,
Params: NamedValueToParams(args),
})
if err != nil {
return nil, err
}
defer func() {
er := res.Close()
if err == nil {
err = er
}
}()
return ExecResult{}, res.Skip(ctx)
}
func (c *Conn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
statements, err := parser.ParseQuery(q)
if err != nil {
return nil, err
}
res, err := query.New(statements...).Run(&query.Context{
Ctx: ctx,
DB: c.DB(),
Conn: c.conn,
Params: NamedValueToParams(args),
})
if err != nil {
return nil, err
}
return NewRows(res)
}
// PrepareContext returns a prepared statement, bound to this connection.
func (c *Conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
statements, err := parser.ParseQuery(q)
if err != nil {
return nil, err
}
if len(statements) != 1 {
return nil, errors.New("cannot insert multiple commands into a prepared statement")
}
sctx := statement.Context{
DB: c.db,
Conn: c.conn,
}
tx, err := c.conn.BeginTx(&database.TxOptions{
ReadOnly: true,
})
if err != nil {
return nil, err
}
defer tx.Rollback()
stmt := statements[0]
if b, ok := stmt.(statement.Bindable); ok {
err = b.Bind(&sctx)
if err != nil {
return nil, err
}
}
if p, ok := stmt.(statement.Preparer); ok {
stmt, err = p.Prepare(&sctx)
if err != nil {
return nil, err
}
}
return &Stmt{
pq: pq,
stmt: stmt,
conn: c,
}, nil
}
@@ -162,7 +240,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
// Stmt is a prepared statement. It is bound to a Conn and not
// used by multiple goroutines concurrently.
type Stmt struct {
pq query.Query
stmt statement.Statement
conn *Conn
}
@@ -178,14 +256,11 @@ func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
// ExecContext executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
if err := ctx.Err(); err != nil {
return nil, err
}
res, err := s.pq.Run(&query.Context{
Ctx: ctx,
res, err := s.stmt.Run(&statement.Context{
DB: s.conn.db,
Conn: s.conn.conn,
Params: NamedValueToParams(args),
@@ -200,7 +275,26 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
}
}()
return ExecResult{}, res.Skip()
return ExecResult{}, res.Skip(ctx)
}
// QueryContext executes a query that may return rows, such as a
// SELECT.
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
res, err := s.stmt.Run(&statement.Context{
DB: s.conn.db,
Conn: s.conn.conn,
Params: NamedValueToParams(args),
})
if err != nil {
return nil, err
}
return NewRows(res)
}
type ExecResult struct{}
@@ -219,28 +313,6 @@ func (s Stmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, errors.New("not implemented")
}
// QueryContext executes a query that may return rows, such as a
// SELECT.
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
res, err := s.pq.Run(&query.Context{
Ctx: ctx,
DB: s.conn.db,
Conn: s.conn.conn,
Params: NamedValueToParams(args),
})
if err != nil {
return nil, err
}
return NewRows(res)
}
// Close does nothing.
func (s Stmt) Close() error {
return nil
@@ -253,6 +325,10 @@ type Rows struct {
}
func NewRows(res *statement.Result) (*Rows, error) {
if res == nil {
return &Rows{}, nil
}
columns, err := res.Columns()
if err != nil {
return nil, err

View File

@@ -26,14 +26,14 @@ func TestParserAlterTable(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, test.expected, q.Statements[0])
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected, stmts[0])
})
}
}
@@ -89,14 +89,14 @@ func TestParserAlterTableAddColumn(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, test.expected, q.Statements[0])
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected, stmts[0])
})
}
}

View File

@@ -49,14 +49,14 @@ func TestParserCreateIndex(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, test.expected, q.Statements[0])
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected, stmts[0])
})
}
}
@@ -190,14 +190,14 @@ func TestParserCreateSequence(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, test.expected, q.Statements[0])
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected, stmts[0])
})
}
}

View File

@@ -9,7 +9,7 @@ import (
// parseDeleteStatement parses a delete string and returns a Statement AST row.
func (p *Parser) parseDeleteStatement() (statement.Statement, error) {
stmt := statement.NewDeleteStatement()
var stmt statement.DeleteStmt
var err error
// Parse "DELETE FROM".
@@ -49,5 +49,5 @@ func (p *Parser) parseDeleteStatement() (statement.Statement, error) {
return nil, err
}
return stmt, nil
return &stmt, nil
}

View File

@@ -1,11 +1,9 @@
package parser_test
import (
"context"
"testing"
"github.com/chaisql/chai/internal/expr"
"github.com/chaisql/chai/internal/query"
"github.com/chaisql/chai/internal/query/statement"
"github.com/chaisql/chai/internal/sql/parser"
"github.com/chaisql/chai/internal/stream"
@@ -23,7 +21,7 @@ func TestParserDelete(t *testing.T) {
parseExpr := func(s string) expr.Expr {
e := parser.MustParseExpr(s)
err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, "test", e)
err := statement.BindExpr(&statement.Context{DB: db, Conn: tx.Connection()}, "test", e)
require.NoError(t, err)
return e
}
@@ -75,18 +73,17 @@ func TestParserDelete(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
require.NoError(t, err)
err = q.Prepare(&query.Context{
Ctx: context.Background(),
require.Len(t, stmts, 1)
stmt, err := stmts[0].(statement.Preparer).Prepare(&statement.Context{
DB: db,
Conn: tx.Connection(),
})
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, &statement.PreparedStreamStmt{Stream: test.expected}, q.Statements[0].(*statement.PreparedStreamStmt))
require.Equal(t, test.expected.String(), stmt.(*statement.DeleteStmt).Stream.String())
})
}
}

View File

@@ -25,14 +25,14 @@ func TestParserDrop(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, test.expected, q.Statements[0])
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected, stmts[0])
})
}
}

View File

@@ -10,7 +10,7 @@ import (
)
func TestParserExplain(t *testing.T) {
slct := statement.NewSelectStatement()
var slct statement.SelectStmt
slct.CompoundSelect = []*statement.SelectCoreStmt{
{TableName: "test", ProjectionExprs: []expr.Expr{expr.Wildcard{}}},
}
@@ -21,20 +21,20 @@ func TestParserExplain(t *testing.T) {
expected statement.Statement
errored bool
}{
{"Explain select", "EXPLAIN SELECT * FROM test", &statement.ExplainStmt{Statement: slct}, false},
{"Explain select", "EXPLAIN SELECT * FROM test", &statement.ExplainStmt{Statement: &slct}, false},
{"Multiple Explains", "EXPLAIN EXPLAIN CREATE TABLE test", nil, true},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, test.expected, q.Statements[0])
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected, stmts[0])
})
}
}

View File

@@ -12,7 +12,7 @@ import (
// parseInsertStatement parses an insert string and returns a Statement AST row.
func (p *Parser) parseInsertStatement() (*statement.InsertStmt, error) {
stmt := statement.NewInsertStatement()
var stmt statement.InsertStmt
var err error
// Parse "INSERT INTO".
@@ -64,7 +64,7 @@ func (p *Parser) parseInsertStatement() (*statement.InsertStmt, error) {
return nil, err
}
return stmt, nil
return &stmt, nil
}
// parseColumnList parses a list of columns in the form: (column, column, ...), if exists.

View File

@@ -1,11 +1,9 @@
package parser_test
import (
"context"
"testing"
"github.com/chaisql/chai/internal/expr"
"github.com/chaisql/chai/internal/query"
"github.com/chaisql/chai/internal/query/statement"
"github.com/chaisql/chai/internal/sql/parser"
"github.com/chaisql/chai/internal/stream"
@@ -221,23 +219,22 @@ func TestParserInsert(t *testing.T) {
testutil.MustExec(t, db, tx, "CREATE TABLE test(a TEXT, b TEXT); CREATE TABLE foo(c TEXT, d TEXT);")
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.fails {
require.Error(t, err)
return
}
require.NoError(t, err)
err = q.Prepare(&query.Context{
Ctx: context.Background(),
require.Len(t, stmts, 1)
stmt, err := stmts[0].(statement.Preparer).Prepare(&statement.Context{
DB: db,
Conn: tx.Connection(),
})
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.Equal(t, test.expected.String(), q.Statements[0].(*statement.PreparedStreamStmt).Stream.String())
require.Equal(t, test.expected.String(), stmt.(*statement.InsertStmt).Stream.String())
})
}
}

View File

@@ -6,7 +6,6 @@ import (
"strings"
"github.com/chaisql/chai/internal/expr"
"github.com/chaisql/chai/internal/query"
"github.com/chaisql/chai/internal/query/statement"
"github.com/chaisql/chai/internal/sql/scanner"
"github.com/chaisql/chai/internal/tree"
@@ -24,7 +23,7 @@ func NewParser(r io.Reader) *Parser {
}
// ParseQuery parses a query string and returns its AST representation.
func ParseQuery(s string) (query.Query, error) {
func ParseQuery(s string) ([]statement.Statement, error) {
return NewParser(strings.NewReader(s)).ParseQuery()
}
@@ -45,7 +44,7 @@ func MustParseExpr(s string) expr.Expr {
}
// ParseQuery parses a Chai SQL string and returns a Query.
func (p *Parser) ParseQuery() (query.Query, error) {
func (p *Parser) ParseQuery() ([]statement.Statement, error) {
var statements []statement.Statement
err := p.Parse(func(s statement.Statement) error {
@@ -53,10 +52,10 @@ func (p *Parser) ParseQuery() (query.Query, error) {
return nil
})
if err != nil {
return query.Query{}, err
return nil, err
}
return query.Query{Statements: statements}, nil
return statements, nil
}
// ParseQuery parses a Chai SQL string and returns a Query.

View File

@@ -10,8 +10,8 @@ import (
func FuzzParseQuery(f *testing.F) {
f.Fuzz(func(t *testing.T, s string) {
// Fuzz ParseQuery for panics.
q, err := ParseQuery(s)
if err != nil || len(q.Statements) < 1 {
statements, err := ParseQuery(s)
if err != nil || len(statements) < 1 {
t.Skip()
}
})

View File

@@ -11,12 +11,12 @@ import (
)
func TestParserMultiStatement(t *testing.T) {
slct := statement.NewSelectStatement()
var slct statement.SelectStmt
slct.CompoundSelect = []*statement.SelectCoreStmt{
{TableName: "foo", ProjectionExprs: []expr.Expr{expr.Wildcard{}}},
}
dlt := statement.NewDeleteStatement()
var dlt statement.DeleteStmt
dlt.TableName = "foo"
tests := []struct {
@@ -26,16 +26,16 @@ func TestParserMultiStatement(t *testing.T) {
}{
{"OnlyCommas", ";;;", nil},
{"TrailingComma", "SELECT * FROM foo;;;DELETE FROM foo;", []statement.Statement{
slct,
dlt,
&slct,
&dlt,
}},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
require.NoError(t, err)
require.EqualValues(t, test.expected, q.Statements)
require.EqualValues(t, test.expected, stmts)
})
}
}

View File

@@ -7,7 +7,7 @@ import (
// parseReindexStatement parses a reindex statement.
func (p *Parser) parseReIndexStatement() (statement.Statement, error) {
stmt := statement.NewReIndexStatement()
var stmt statement.ReIndexStmt
// Parse "REINDEX".
if err := p.ParseTokens(scanner.REINDEX); err != nil {
@@ -20,5 +20,5 @@ func (p *Parser) parseReIndexStatement() (statement.Statement, error) {
} else {
p.Unscan()
}
return stmt, nil
return &stmt, nil
}

View File

@@ -9,8 +9,8 @@ import (
)
func TestParserReIndex(t *testing.T) {
r1 := statement.NewReIndexStatement()
r2 := statement.NewReIndexStatement()
var r1 statement.ReIndexStmt
var r2 statement.ReIndexStmt
r2.TableOrIndexName = "tableOrIndex"
tests := []struct {
name string
@@ -18,21 +18,21 @@ func TestParserReIndex(t *testing.T) {
expected statement.Statement
errored bool
}{
{"All", "REINDEX", r1, false},
{"With ident", "REINDEX tableOrIndex", r2, false},
{"All", "REINDEX", &r1, false},
{"With ident", "REINDEX tableOrIndex", &r2, false},
{"With extra", "REINDEX tableOrIndex tableOrIndex", nil, true},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, test.expected, q.Statements[0])
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected, stmts[0])
})
}
}

View File

@@ -10,10 +10,10 @@ import (
// parseSelectStatement parses a select string and returns a Statement AST row.
// This function assumes the SELECT token has already been consumed.
func (p *Parser) parseSelectStatement() (*statement.SelectStmt, error) {
stmt := statement.NewSelectStatement()
var stmt statement.SelectStmt
// Parse SELECT ... [UNION | UNION ALL | INTERSECT] SELECT ...
err := p.parseCompoundSelectStatement(stmt)
err := p.parseCompoundSelectStatement(&stmt)
if err != nil {
return nil, err
}
@@ -36,7 +36,7 @@ func (p *Parser) parseSelectStatement() (*statement.SelectStmt, error) {
return nil, errors.Wrap(err, "failed to parse OFFSET clause")
}
return stmt, nil
return &stmt, nil
}
func (p *Parser) parseCompoundSelectStatement(stmt *statement.SelectStmt) error {

View File

@@ -1,12 +1,10 @@
package parser_test
import (
"context"
"testing"
"github.com/chaisql/chai/internal/expr"
"github.com/chaisql/chai/internal/expr/functions"
"github.com/chaisql/chai/internal/query"
"github.com/chaisql/chai/internal/query/statement"
"github.com/chaisql/chai/internal/sql/parser"
"github.com/chaisql/chai/internal/stream"
@@ -37,7 +35,7 @@ func TestParserSelect(t *testing.T) {
if len(table) > 0 {
tb = table[0]
}
err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, tb, e)
err := statement.BindExpr(&statement.Context{DB: db, Conn: tx.Connection()}, tb, e)
require.NoError(t, err)
return e
}
@@ -391,27 +389,20 @@ func TestParserSelect(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.mustFail {
require.Error(t, err)
return
}
require.Len(t, stmts, 1)
err = q.Prepare(&query.Context{
Ctx: context.Background(),
stmt, err := stmts[0].(statement.Preparer).Prepare(&statement.Context{
DB: db,
Conn: tx.Connection(),
})
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, &statement.PreparedStreamStmt{ReadOnly: test.readOnly, Stream: test.expected}, q.Statements[0].(*statement.PreparedStreamStmt))
require.Equal(t, test.expected.String(), stmt.(*statement.SelectStmt).Stream.String())
})
}
}
func BenchmarkSelect(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = parser.ParseQuery("SELECT a, b AS `foo` FROM `some table` WHERE d.e[100] >= 12 AND c.d IN ([1, true], [2, false]) GROUP BY d.e[0] LIMIT 10 + 10 OFFSET 20 - 20 ORDER BY d DESC")
}
}

View File

@@ -29,14 +29,14 @@ func TestParserTransactions(t *testing.T) {
for _, test := range tests {
t.Run(test.s, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, test.expected, q.Statements[0])
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected, stmts[0])
})
}
}

View File

@@ -8,7 +8,7 @@ import (
// parseUpdateStatement parses a update string and returns a Statement AST row.
func (p *Parser) parseUpdateStatement() (*statement.UpdateStmt, error) {
stmt := statement.NewUpdateStatement()
var stmt statement.UpdateStmt
var err error
// Parse "UPDATE".
@@ -42,7 +42,7 @@ func (p *Parser) parseUpdateStatement() (*statement.UpdateStmt, error) {
return nil, err
}
return stmt, nil
return &stmt, nil
}
// parseSetClause parses the "SET" clause of the query.

View File

@@ -1,11 +1,9 @@
package parser_test
import (
"context"
"testing"
"github.com/chaisql/chai/internal/expr"
"github.com/chaisql/chai/internal/query"
"github.com/chaisql/chai/internal/query/statement"
"github.com/chaisql/chai/internal/sql/parser"
"github.com/chaisql/chai/internal/stream"
@@ -28,7 +26,7 @@ func TestParserUpdate(t *testing.T) {
if len(table) > 0 {
tb = table[0]
}
err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, tb, e)
err := statement.BindExpr(&statement.Context{DB: db, Conn: tx.Connection()}, tb, e)
require.NoError(t, err)
return e
}
@@ -66,22 +64,21 @@ func TestParserUpdate(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
q, err := parser.ParseQuery(test.s)
stmts, err := parser.ParseQuery(test.s)
if test.errored {
require.Error(t, err)
return
}
require.NoError(t, err)
err = q.Prepare(&query.Context{
Ctx: context.Background(),
stmt, err := stmts[0].(statement.Preparer).Prepare(&statement.Context{
DB: db,
Conn: tx.Connection(),
})
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, &statement.PreparedStreamStmt{Stream: test.expected}, q.Statements[0].(*statement.PreparedStreamStmt))
require.Len(t, stmts, 1)
require.EqualValues(t, test.expected.String(), stmt.(*statement.UpdateStmt).Stream.String())
})
}
}

View File

@@ -85,8 +85,7 @@ func NewTestTx(t testing.TB) (*database.Database, *database.Transaction, func())
require.NoError(t, err)
return db, tx, func() {
err = tx.Rollback()
require.NoError(t, err)
_ = tx.Rollback()
}
}
@@ -97,7 +96,7 @@ func Exec(db *database.Database, tx *database.Transaction, q string, params ...e
}
defer res.Close()
return res.Skip()
return res.Skip(context.Background())
}
func Query(db *database.Database, tx *database.Transaction, q string, params ...environment.Param) (*statement.Result, error) {
@@ -107,12 +106,8 @@ func Query(db *database.Database, tx *database.Transaction, q string, params ...
}
ctx := &query.Context{Ctx: context.Background(), DB: db, Conn: tx.Connection(), Params: params}
err = pq.Prepare(ctx)
if err != nil {
return nil, err
}
return pq.Run(ctx)
return query.New(pq...).Run(ctx)
}
func MustExec(t *testing.T, db *database.Database, tx *database.Transaction, q string, params ...environment.Param) {