mirror of
https://github.com/chaisql/chai.git
synced 2025-09-26 19:51:21 +08:00
fix prepared statements
This commit is contained in:
@@ -31,18 +31,11 @@ func ExecSQL(ctx context.Context, db *sql.DB, r io.Reader, w io.Writer) error {
|
||||
|
||||
var stmtWithOutputCount int
|
||||
return parser.NewParser(r).Parse(func(s statement.Statement) error {
|
||||
qq := query.New(s)
|
||||
qctx := query.Context{
|
||||
res, err := query.New(s).Run(&query.Context{
|
||||
Ctx: ctx,
|
||||
DB: conn.DB(),
|
||||
Conn: conn.Conn(),
|
||||
}
|
||||
err := qq.Prepare(&qctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := qq.Run(&qctx)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -19,12 +19,18 @@ func TestExecSQL(t *testing.T) {
|
||||
err = ExecSQL(t.Context(), db, strings.NewReader(`
|
||||
CREATE TABLE test(a INT, b TEXT);
|
||||
CREATE INDEX idx_a ON test (a);
|
||||
BEGIN;
|
||||
INSERT INTO test (a, b) VALUES (10, 'aa'), (20, 'bb'), (30, 'cc');
|
||||
ROLLBACK;
|
||||
BEGIN;
|
||||
INSERT INTO test (a, b) VALUES (1, 'a'), (2, 'b'), (3, 'c');
|
||||
SELECT * FROM test;
|
||||
COMMIT;
|
||||
SELECT b, a FROM test;
|
||||
`), &got)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "a|b\n1|\"a\"\n2|\"b\"\n3|\"c\"\n", got.String())
|
||||
require.Equal(t, "a|b\n1|\"a\"\n2|\"b\"\n3|\"c\"\n\nb|a\n\"a\"|1\n\"b\"|2\n\"c\"|3\n", got.String())
|
||||
|
||||
var res struct {
|
||||
A int
|
||||
|
@@ -70,12 +70,12 @@ func ListIndexes(ctx context.Context, db *sql.DB, tableName string) ([]string, e
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q, err := parser.ParseQuery(query)
|
||||
statements, err := parser.ParseQuery(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listName = append(listName, q.Statements[0].(*statement.CreateIndexStmt).Info.IndexName)
|
||||
listName = append(listName, statements[0].(*statement.CreateIndexStmt).Info.IndexName)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@@ -20,7 +20,13 @@ import (
|
||||
func updateCatalog(t testing.TB, db *database.Database, fn func(tx *database.Transaction, catalog *database.CatalogWriter) error) {
|
||||
t.Helper()
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
|
@@ -95,7 +95,7 @@ func Open(path string, opts *Options) (*Database, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
tx, err := db.begin(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -183,7 +183,7 @@ func (db *Database) Connect() (*Connection, error) {
|
||||
|
||||
// Begin starts a new transaction with default options.
|
||||
// The returned transaction must be closed either by calling Rollback or Commit.
|
||||
func (db *Database) Begin(writable bool) (*Transaction, error) {
|
||||
func (db *Database) begin(writable bool) (*Transaction, error) {
|
||||
if db.closeContext.Err() != nil {
|
||||
return nil, errors.New("database is closed")
|
||||
}
|
||||
|
@@ -157,7 +157,13 @@ func TestSequence(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -193,7 +199,13 @@ func TestSequence(t *testing.T) {
|
||||
t.Run("default cache", func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -230,7 +242,13 @@ func TestSequence(t *testing.T) {
|
||||
t.Run("cache", func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -276,7 +294,13 @@ func TestSequence(t *testing.T) {
|
||||
t.Run("cache desc", func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -323,7 +347,13 @@ func TestSequence(t *testing.T) {
|
||||
t.Run("read-only", func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.CatalogWriter().CreateSequence(tx, &database.SequenceInfo{
|
||||
@@ -339,7 +369,9 @@ func TestSequence(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// open a read-only tx
|
||||
tx, err = db.Begin(false)
|
||||
tx, err = conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -353,7 +385,13 @@ func TestSequence(t *testing.T) {
|
||||
t.Run("release", func(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
|
@@ -21,7 +21,13 @@ var errDontCommit = errors.New("don't commit please")
|
||||
func update(t testing.TB, db *database.Database, fn func(tx *database.Transaction) error) {
|
||||
t.Helper()
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -63,7 +69,7 @@ func createTable(t testing.TB, tx *database.Transaction, info database.TableInfo
|
||||
stmt := statement.CreateTableStmt{Info: info}
|
||||
|
||||
res, err := stmt.Run(&statement.Context{
|
||||
Tx: tx,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
res.Close()
|
||||
@@ -80,7 +86,7 @@ func createTableIfNotExists(t testing.TB, tx *database.Transaction, info databas
|
||||
stmt := statement.CreateTableStmt{Info: info, IfNotExists: true}
|
||||
|
||||
res, err := stmt.Run(&statement.Context{
|
||||
Tx: tx,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
res.Close()
|
||||
|
@@ -13,13 +13,11 @@ import (
|
||||
// Results are returned as streams.
|
||||
type Query struct {
|
||||
Statements []statement.Statement
|
||||
tx *database.Transaction
|
||||
autoCommit bool
|
||||
}
|
||||
|
||||
// New creates a new query with the given statements.
|
||||
func New(statements ...statement.Statement) Query {
|
||||
return Query{Statements: statements}
|
||||
func New(statements ...statement.Statement) *Query {
|
||||
return &Query{Statements: statements}
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
@@ -29,133 +27,101 @@ type Context struct {
|
||||
Params []environment.Param
|
||||
}
|
||||
|
||||
func (c *Context) GetTx() *database.Transaction {
|
||||
return c.Conn.GetTx()
|
||||
}
|
||||
|
||||
// Prepare the statements by calling their Prepare methods.
|
||||
// It stops at the first statement that doesn't implement the statement.Preparer interface.
|
||||
func (q *Query) Prepare(context *Context) error {
|
||||
// Run executes all the statements in their own transaction and returns the last result.
|
||||
func (q *Query) Run(c *Context) (*statement.Result, error) {
|
||||
var res *statement.Result
|
||||
var err error
|
||||
|
||||
ctx := c.Ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
sctx := statement.Context{
|
||||
DB: c.DB,
|
||||
Conn: c.Conn,
|
||||
Params: c.Params,
|
||||
}
|
||||
|
||||
var tx *database.Transaction
|
||||
|
||||
ctx := context.Ctx
|
||||
|
||||
for i, stmt := range q.Statements {
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tx == nil {
|
||||
tx = context.GetTx()
|
||||
if tx == nil {
|
||||
tx, err = context.Conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
}
|
||||
}
|
||||
|
||||
sctx := &statement.Context{
|
||||
DB: context.DB,
|
||||
Conn: context.Conn,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
err = stmt.Bind(sctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p, ok := stmt.(statement.Preparer)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := p.Prepare(sctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.Statements[i] = stmt
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run executes all the statements in their own transaction and returns the last result.
|
||||
func (q *Query) Run(context *Context) (*statement.Result, error) {
|
||||
var res statement.Result
|
||||
var err error
|
||||
|
||||
q.tx = context.GetTx()
|
||||
if q.tx == nil {
|
||||
q.autoCommit = true
|
||||
}
|
||||
|
||||
ctx := context.Ctx
|
||||
|
||||
for i, stmt := range q.Statements {
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
// reinitialize the result
|
||||
res = statement.Result{}
|
||||
res = nil
|
||||
|
||||
if qa, ok := stmt.(queryAlterer); ok {
|
||||
err = qa.alterQuery(context.Conn, q)
|
||||
if err != nil {
|
||||
if tx := context.GetTx(); tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
continue
|
||||
// handles transactions
|
||||
isReadOnly := false
|
||||
if ro, ok := stmt.(statement.ReadOnly); ok {
|
||||
isReadOnly = ro.IsReadOnly()
|
||||
}
|
||||
|
||||
if q.tx == nil {
|
||||
q.tx, err = context.Conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: stmt.IsReadOnly(),
|
||||
needsTx := true
|
||||
if ntx, ok := stmt.(statement.Transactional); ok {
|
||||
needsTx = ntx.NeedsTransaction()
|
||||
}
|
||||
|
||||
var autoCommit bool
|
||||
|
||||
if c.Conn.GetTx() == nil && needsTx {
|
||||
autoCommit = true
|
||||
|
||||
tx, err = c.Conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: isReadOnly,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
res, err = stmt.Run(&statement.Context{
|
||||
DB: context.DB,
|
||||
Conn: context.Conn,
|
||||
Tx: q.tx,
|
||||
Params: context.Params,
|
||||
})
|
||||
// bind
|
||||
if b, ok := stmt.(statement.Bindable); ok {
|
||||
err = b.Bind(&sctx)
|
||||
if err != nil {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// prepare
|
||||
if prep, ok := stmt.(statement.Preparer); ok {
|
||||
stmt, err = prep.Prepare(&sctx)
|
||||
if err != nil {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// run
|
||||
res, err = stmt.Run(&sctx)
|
||||
if err != nil {
|
||||
if q.autoCommit {
|
||||
_ = q.tx.Rollback()
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
if res == nil {
|
||||
res = &statement.Result{}
|
||||
}
|
||||
|
||||
// if there are still statements to be executed,
|
||||
// and the current statement is not read-only,
|
||||
// iterate over the result.
|
||||
if !stmt.IsReadOnly() && i+1 < len(q.Statements) {
|
||||
err = res.Skip()
|
||||
if res != nil && !isReadOnly && i+1 < len(q.Statements) {
|
||||
err = res.Skip(ctx)
|
||||
if err != nil {
|
||||
if q.autoCommit {
|
||||
_ = q.tx.Rollback()
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
@@ -164,31 +130,27 @@ func (q *Query) Run(context *Context) (*statement.Result, error) {
|
||||
|
||||
// it there is an opened transaction but there are still statements
|
||||
// to be executed, close the current transaction.
|
||||
if q.tx != nil && q.autoCommit && i+1 < len(q.Statements) {
|
||||
if q.tx.Writable {
|
||||
err := q.tx.Commit()
|
||||
if autoCommit && i+1 < len(q.Statements) {
|
||||
if tx.Writable {
|
||||
err := tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
err := q.tx.Rollback()
|
||||
err := tx.Rollback()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
q.tx = nil
|
||||
tx = nil
|
||||
}
|
||||
}
|
||||
|
||||
if q.autoCommit {
|
||||
if tx != nil {
|
||||
// the returned result will now own the transaction.
|
||||
// its Close method is expected to be called.
|
||||
res.Tx = q.tx
|
||||
res.Tx = tx
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
type queryAlterer interface {
|
||||
alterQuery(conn *database.Connection, q *Query) error
|
||||
return res, nil
|
||||
}
|
||||
|
@@ -18,34 +18,23 @@ type AlterTableRenameStmt struct {
|
||||
NewTableName string
|
||||
}
|
||||
|
||||
func (stmt *AlterTableRenameStmt) Bind(ctx *Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReadOnly always returns false. It implements the Statement interface.
|
||||
func (stmt *AlterTableRenameStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Run runs the ALTER TABLE statement in the given transaction.
|
||||
// It implements the Statement interface.
|
||||
func (stmt *AlterTableRenameStmt) Run(ctx *Context) (Result, error) {
|
||||
var res Result
|
||||
|
||||
func (stmt *AlterTableRenameStmt) Run(ctx *Context) (*Result, error) {
|
||||
if stmt.TableName == "" {
|
||||
return res, errors.New("missing table name")
|
||||
return nil, errors.New("missing table name")
|
||||
}
|
||||
|
||||
if stmt.NewTableName == "" {
|
||||
return res, errors.New("missing new table name")
|
||||
return nil, errors.New("missing new table name")
|
||||
}
|
||||
|
||||
if stmt.TableName == stmt.NewTableName {
|
||||
return res, errs.AlreadyExistsError{Name: stmt.NewTableName}
|
||||
return nil, errs.AlreadyExistsError{Name: stmt.NewTableName}
|
||||
}
|
||||
|
||||
err := ctx.Tx.CatalogWriter().RenameTable(ctx.Tx, stmt.TableName, stmt.NewTableName)
|
||||
return res, err
|
||||
err := ctx.Conn.GetTx().CatalogWriter().RenameTable(ctx.Conn.GetTx(), stmt.TableName, stmt.NewTableName)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type AlterTableAddColumnStmt struct {
|
||||
@@ -54,41 +43,32 @@ type AlterTableAddColumnStmt struct {
|
||||
TableConstraints database.TableConstraints
|
||||
}
|
||||
|
||||
// IsReadOnly always returns false. It implements the Statement interface.
|
||||
func (stmt *AlterTableAddColumnStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt *AlterTableAddColumnStmt) Bind(ctx *Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run runs the ALTER TABLE ADD COLUMN statement in the given transaction.
|
||||
// It implements the Statement interface.
|
||||
// The statement rebuilds the table.
|
||||
func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
|
||||
func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (*Result, error) {
|
||||
var err error
|
||||
|
||||
// get the table before adding the column constraint
|
||||
// and assign the table to the table.Scan operator
|
||||
// so that it can decode the records properly
|
||||
scan := table.Scan(stmt.TableName)
|
||||
scan.Table, err = ctx.Tx.Catalog.GetTable(ctx.Tx, stmt.TableName)
|
||||
scan.Table, err = ctx.Conn.GetTx().Catalog.GetTable(ctx.Conn.GetTx(), stmt.TableName)
|
||||
if err != nil {
|
||||
return Result{}, errors.Wrap(err, "failed to get table")
|
||||
return nil, errors.Wrap(err, "failed to get table")
|
||||
}
|
||||
|
||||
// get the current list of indexes
|
||||
indexNames := ctx.Tx.Catalog.ListIndexes(stmt.TableName)
|
||||
indexNames := ctx.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
|
||||
|
||||
// add the column constraint to the table
|
||||
err = ctx.Tx.CatalogWriter().AddColumnConstraint(
|
||||
ctx.Tx,
|
||||
err = ctx.Conn.GetTx().CatalogWriter().AddColumnConstraint(
|
||||
ctx.Conn.GetTx(),
|
||||
stmt.TableName,
|
||||
stmt.ColumnConstraint,
|
||||
stmt.TableConstraints)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create a unique index for every unique constraint
|
||||
@@ -96,7 +76,7 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
|
||||
var newIdxs []*database.IndexInfo
|
||||
for _, tc := range stmt.TableConstraints {
|
||||
if tc.Unique {
|
||||
idx, err := ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{
|
||||
idx, err := ctx.Conn.GetTx().CatalogWriter().CreateIndex(ctx.Conn.GetTx(), &database.IndexInfo{
|
||||
Columns: tc.Columns,
|
||||
Unique: true,
|
||||
Owner: database.Owner{
|
||||
@@ -105,7 +85,7 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newIdxs = append(newIdxs, idx)
|
||||
@@ -141,11 +121,11 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
|
||||
s = s.Pipe(table.Insert(stmt.TableName))
|
||||
|
||||
// insert the record into the all the indexes
|
||||
indexNames = ctx.Tx.Catalog.ListIndexes(stmt.TableName)
|
||||
indexNames = ctx.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
|
||||
for _, indexName := range indexNames {
|
||||
info, err := ctx.Tx.Catalog.GetIndexInfo(indexName)
|
||||
info, err := ctx.Conn.GetTx().Catalog.GetIndexInfo(indexName)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
return nil, err
|
||||
}
|
||||
if info.Unique {
|
||||
s = s.Pipe(index.Validate(indexName))
|
||||
@@ -176,7 +156,7 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) {
|
||||
s = s.Pipe(stream.Discard())
|
||||
|
||||
// do NOT optimize the stream
|
||||
return Result{
|
||||
return &Result{
|
||||
Result: &StreamStmtResult{
|
||||
Stream: s,
|
||||
Context: ctx,
|
||||
|
@@ -20,20 +20,9 @@ type CreateTableStmt struct {
|
||||
Info database.TableInfo
|
||||
}
|
||||
|
||||
// IsReadOnly always returns false. It implements the Statement interface.
|
||||
func (stmt *CreateTableStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt *CreateTableStmt) Bind(ctx *Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run runs the Create table statement in the given transaction.
|
||||
// It implements the Statement interface.
|
||||
func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) {
|
||||
var res Result
|
||||
|
||||
func (stmt *CreateTableStmt) Run(ctx *Context) (*Result, error) {
|
||||
// if there is no primary key, create a rowid sequence
|
||||
if stmt.Info.PrimaryKey == nil {
|
||||
seq := database.SequenceInfo{
|
||||
@@ -45,25 +34,25 @@ func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) {
|
||||
TableName: stmt.Info.TableName,
|
||||
},
|
||||
}
|
||||
err := ctx.Tx.CatalogWriter().CreateSequence(ctx.Tx, &seq)
|
||||
err := ctx.Conn.GetTx().CatalogWriter().CreateSequence(ctx.Conn.GetTx(), &seq)
|
||||
if err != nil {
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmt.Info.RowidSequenceName = seq.Name
|
||||
}
|
||||
|
||||
err := ctx.Tx.CatalogWriter().CreateTable(ctx.Tx, stmt.Info.TableName, &stmt.Info)
|
||||
err := ctx.Conn.GetTx().CatalogWriter().CreateTable(ctx.Conn.GetTx(), stmt.Info.TableName, &stmt.Info)
|
||||
if stmt.IfNotExists {
|
||||
if errs.IsAlreadyExistsError(err) {
|
||||
return res, nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// create a unique index for every unique constraint
|
||||
for _, tc := range stmt.Info.TableConstraints {
|
||||
if tc.Unique {
|
||||
_, err = ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{
|
||||
_, err = ctx.Conn.GetTx().CatalogWriter().CreateIndex(ctx.Conn.GetTx(), &database.IndexInfo{
|
||||
Columns: tc.Columns,
|
||||
Unique: true,
|
||||
Owner: database.Owner{
|
||||
@@ -73,12 +62,12 @@ func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) {
|
||||
KeySortOrder: tc.SortOrder,
|
||||
})
|
||||
if err != nil {
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// CreateIndexStmt represents a parsed CREATE INDEX statement.
|
||||
@@ -87,28 +76,17 @@ type CreateIndexStmt struct {
|
||||
Info database.IndexInfo
|
||||
}
|
||||
|
||||
// IsReadOnly always returns false. It implements the Statement interface.
|
||||
func (stmt *CreateIndexStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt *CreateIndexStmt) Bind(ctx *Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run runs the Create index statement in the given transaction.
|
||||
// It implements the Statement interface.
|
||||
func (stmt *CreateIndexStmt) Run(ctx *Context) (Result, error) {
|
||||
var res Result
|
||||
|
||||
_, err := ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &stmt.Info)
|
||||
func (stmt *CreateIndexStmt) Run(ctx *Context) (*Result, error) {
|
||||
_, err := ctx.Conn.GetTx().CatalogWriter().CreateIndex(ctx.Conn.GetTx(), &stmt.Info)
|
||||
if stmt.IfNotExists {
|
||||
if errs.IsAlreadyExistsError(err) {
|
||||
return res, nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := stream.New(table.Scan(stmt.Info.Owner.TableName)).
|
||||
@@ -116,8 +94,7 @@ func (stmt *CreateIndexStmt) Run(ctx *Context) (Result, error) {
|
||||
Pipe(stream.Discard())
|
||||
|
||||
ss := PreparedStreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
Stream: s,
|
||||
}
|
||||
|
||||
return ss.Run(ctx)
|
||||
@@ -129,25 +106,14 @@ type CreateSequenceStmt struct {
|
||||
Info database.SequenceInfo
|
||||
}
|
||||
|
||||
// IsReadOnly always returns false. It implements the Statement interface.
|
||||
func (stmt *CreateSequenceStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt *CreateSequenceStmt) Bind(ctx *Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run the statement in the given transaction.
|
||||
// It implements the Statement interface.
|
||||
func (stmt *CreateSequenceStmt) Run(ctx *Context) (Result, error) {
|
||||
var res Result
|
||||
|
||||
err := ctx.Tx.CatalogWriter().CreateSequence(ctx.Tx, &stmt.Info)
|
||||
func (stmt *CreateSequenceStmt) Run(ctx *Context) (*Result, error) {
|
||||
err := ctx.Conn.GetTx().CatalogWriter().CreateSequence(ctx.Conn.GetTx(), &stmt.Info)
|
||||
if stmt.IfNotExists {
|
||||
if errs.IsAlreadyExistsError(err) {
|
||||
return res, nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -13,7 +13,7 @@ var _ Statement = (*DeleteStmt)(nil)
|
||||
|
||||
// DeleteConfig holds DELETE configuration.
|
||||
type DeleteStmt struct {
|
||||
basePreparedStatement
|
||||
PreparedStreamStmt
|
||||
|
||||
TableName string
|
||||
WhereExpr expr.Expr
|
||||
@@ -23,17 +23,6 @@ type DeleteStmt struct {
|
||||
OrderByDirection scanner.Token
|
||||
}
|
||||
|
||||
func NewDeleteStatement() *DeleteStmt {
|
||||
var p DeleteStmt
|
||||
|
||||
p.basePreparedStatement = basePreparedStatement{
|
||||
Preparer: &p,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
return &p
|
||||
}
|
||||
|
||||
func (stmt *DeleteStmt) Bind(ctx *Context) error {
|
||||
err := BindExpr(ctx, stmt.TableName, stmt.WhereExpr)
|
||||
if err != nil {
|
||||
@@ -81,7 +70,7 @@ func (stmt *DeleteStmt) Prepare(c *Context) (Statement, error) {
|
||||
s = s.Pipe(rows.Take(stmt.LimitExpr))
|
||||
}
|
||||
|
||||
indexNames := c.Tx.Catalog.ListIndexes(stmt.TableName)
|
||||
indexNames := c.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
|
||||
for _, indexName := range indexNames {
|
||||
s = s.Pipe(index.Delete(indexName))
|
||||
}
|
||||
@@ -90,10 +79,6 @@ func (stmt *DeleteStmt) Prepare(c *Context) (Statement, error) {
|
||||
|
||||
s = s.Pipe(stream.Discard())
|
||||
|
||||
st := StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
return st.Prepare(c)
|
||||
stmt.PreparedStreamStmt.Stream = s
|
||||
return stmt, nil
|
||||
}
|
||||
|
@@ -28,36 +28,34 @@ func (stmt *DropTableStmt) Bind(ctx *Context) error {
|
||||
|
||||
// Run runs the DropTable statement in the given transaction.
|
||||
// It implements the Statement interface.
|
||||
func (stmt *DropTableStmt) Run(ctx *Context) (Result, error) {
|
||||
var res Result
|
||||
|
||||
func (stmt *DropTableStmt) Run(ctx *Context) (*Result, error) {
|
||||
if stmt.TableName == "" {
|
||||
return res, errors.New("missing table name")
|
||||
return nil, errors.New("missing table name")
|
||||
}
|
||||
|
||||
tb, err := ctx.Tx.Catalog.GetTable(ctx.Tx, stmt.TableName)
|
||||
tb, err := ctx.Conn.GetTx().Catalog.GetTable(ctx.Conn.GetTx(), stmt.TableName)
|
||||
if err != nil {
|
||||
if errs.IsNotFoundError(err) && stmt.IfExists {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = ctx.Tx.CatalogWriter().DropTable(ctx.Tx, stmt.TableName)
|
||||
err = ctx.Conn.GetTx().CatalogWriter().DropTable(ctx.Conn.GetTx(), stmt.TableName)
|
||||
if err != nil {
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if there is no primary key, drop the rowid sequence
|
||||
if tb.Info.PrimaryKey == nil {
|
||||
err = ctx.Tx.CatalogWriter().DropSequence(ctx.Tx, tb.Info.RowidSequenceName)
|
||||
err = ctx.Conn.GetTx().CatalogWriter().DropSequence(ctx.Conn.GetTx(), tb.Info.RowidSequenceName)
|
||||
if err != nil {
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// DropIndexStmt is a DSL that allows creating a DROP INDEX query.
|
||||
@@ -77,19 +75,17 @@ func (stmt *DropIndexStmt) Bind(ctx *Context) error {
|
||||
|
||||
// Run runs the DropIndex statement in the given transaction.
|
||||
// It implements the Statement interface.
|
||||
func (stmt *DropIndexStmt) Run(ctx *Context) (Result, error) {
|
||||
var res Result
|
||||
|
||||
func (stmt *DropIndexStmt) Run(ctx *Context) (*Result, error) {
|
||||
if stmt.IndexName == "" {
|
||||
return res, errors.New("missing index name")
|
||||
return nil, errors.New("missing index name")
|
||||
}
|
||||
|
||||
err := ctx.Tx.CatalogWriter().DropIndex(ctx.Tx, stmt.IndexName)
|
||||
err := ctx.Conn.GetTx().CatalogWriter().DropIndex(ctx.Conn.GetTx(), stmt.IndexName)
|
||||
if errs.IsNotFoundError(err) && stmt.IfExists {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// DropSequenceStmt is a DSL that allows creating a DROP INDEX query.
|
||||
@@ -109,29 +105,27 @@ func (stmt *DropSequenceStmt) Bind(ctx *Context) error {
|
||||
|
||||
// Run runs the DropSequence statement in the given transaction.
|
||||
// It implements the Statement interface.
|
||||
func (stmt *DropSequenceStmt) Run(ctx *Context) (Result, error) {
|
||||
var res Result
|
||||
|
||||
func (stmt *DropSequenceStmt) Run(ctx *Context) (*Result, error) {
|
||||
if stmt.SequenceName == "" {
|
||||
return res, errors.New("missing index name")
|
||||
return nil, errors.New("missing index name")
|
||||
}
|
||||
|
||||
seq, err := ctx.Tx.Catalog.GetSequence(stmt.SequenceName)
|
||||
seq, err := ctx.Conn.GetTx().Catalog.GetSequence(stmt.SequenceName)
|
||||
if err != nil {
|
||||
if errs.IsNotFoundError(err) && stmt.IfExists {
|
||||
err = nil
|
||||
}
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if seq.Info.Owner.TableName != "" {
|
||||
return res, fmt.Errorf("cannot drop sequence %s because constraint of table %s requires it", seq.Info.Name, seq.Info.Owner.TableName)
|
||||
return nil, fmt.Errorf("cannot drop sequence %s because constraint of table %s requires it", seq.Info.Name, seq.Info.Owner.TableName)
|
||||
}
|
||||
|
||||
err = ctx.Tx.CatalogWriter().DropSequence(ctx.Tx, stmt.SequenceName)
|
||||
err = ctx.Conn.GetTx().CatalogWriter().DropSequence(ctx.Conn.GetTx(), stmt.SequenceName)
|
||||
if errs.IsNotFoundError(err) && stmt.IfExists {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -19,7 +19,7 @@ type ExplainStmt struct {
|
||||
}
|
||||
|
||||
func (stmt *ExplainStmt) Bind(ctx *Context) error {
|
||||
if s, ok := stmt.Statement.(Statement); ok {
|
||||
if s, ok := stmt.Statement.(Bindable); ok {
|
||||
return s.Bind(ctx)
|
||||
}
|
||||
|
||||
@@ -30,26 +30,36 @@ func (stmt *ExplainStmt) Bind(ctx *Context) error {
|
||||
// If the statement is a stream, Optimize will be called prior to
|
||||
// displaying all the operations.
|
||||
// Explain currently only works on SELECT, UPDATE, INSERT and DELETE statements.
|
||||
func (stmt *ExplainStmt) Run(ctx *Context) (Result, error) {
|
||||
func (stmt *ExplainStmt) Run(ctx *Context) (*Result, error) {
|
||||
st, err := stmt.Statement.Prepare(ctx)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, ok := st.(*PreparedStreamStmt)
|
||||
if !ok {
|
||||
return Result{}, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements")
|
||||
var s *stream.Stream
|
||||
|
||||
switch stmt := st.(type) {
|
||||
case *InsertStmt:
|
||||
s = stmt.Stream
|
||||
case *SelectStmt:
|
||||
s = stmt.Stream
|
||||
case *UpdateStmt:
|
||||
s = stmt.Stream
|
||||
case *DeleteStmt:
|
||||
s = stmt.Stream
|
||||
default:
|
||||
return nil, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements")
|
||||
}
|
||||
|
||||
// Optimize the stream.
|
||||
s.Stream, err = planner.Optimize(s.Stream, ctx.Tx.Catalog, ctx.Params)
|
||||
s, err = planner.Optimize(s, ctx.Conn.GetTx().Catalog, ctx.Params)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var plan string
|
||||
if s.Stream != nil {
|
||||
plan = s.Stream.String()
|
||||
if s != nil {
|
||||
plan = s.String()
|
||||
} else {
|
||||
plan = "<no exec>"
|
||||
}
|
||||
@@ -62,7 +72,6 @@ func (stmt *ExplainStmt) Run(ctx *Context) (Result, error) {
|
||||
Expr: expr.LiteralValue{Value: types.NewTextValue(plan)},
|
||||
}),
|
||||
},
|
||||
ReadOnly: true,
|
||||
}
|
||||
return newStatement.Run(ctx)
|
||||
}
|
||||
|
@@ -15,7 +15,7 @@ var _ Statement = (*InsertStmt)(nil)
|
||||
|
||||
// InsertStmt holds INSERT configuration.
|
||||
type InsertStmt struct {
|
||||
basePreparedStatement
|
||||
PreparedStreamStmt
|
||||
|
||||
TableName string
|
||||
Values []expr.Expr
|
||||
@@ -25,17 +25,6 @@ type InsertStmt struct {
|
||||
OnConflict database.OnConflictAction
|
||||
}
|
||||
|
||||
func NewInsertStatement() *InsertStmt {
|
||||
var p InsertStmt
|
||||
|
||||
p.basePreparedStatement = basePreparedStatement{
|
||||
Preparer: &p,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
return &p
|
||||
}
|
||||
|
||||
func (stmt *InsertStmt) Bind(ctx *Context) error {
|
||||
for i := range stmt.Values {
|
||||
err := BindExpr(ctx, stmt.TableName, stmt.Values[i])
|
||||
@@ -45,7 +34,7 @@ func (stmt *InsertStmt) Bind(ctx *Context) error {
|
||||
}
|
||||
|
||||
if stmt.SelectStmt != nil {
|
||||
if s, ok := stmt.SelectStmt.(Statement); ok {
|
||||
if s, ok := stmt.SelectStmt.(Bindable); ok {
|
||||
err := s.Bind(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -68,7 +57,7 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
||||
|
||||
var columns []string
|
||||
if stmt.Values != nil {
|
||||
ti, err := c.Tx.Catalog.GetTableInfo(stmt.TableName)
|
||||
ti, err := c.Conn.GetTx().Catalog.GetTableInfo(stmt.TableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -129,7 +118,7 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = selectStream.(*PreparedStreamStmt).Stream
|
||||
s = selectStream.(*SelectStmt).Stream
|
||||
|
||||
// ensure we are not reading and writing to the same table.
|
||||
// TODO(asdine): if same table, write content to a temp table.
|
||||
@@ -157,9 +146,9 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
||||
}
|
||||
|
||||
// check unique constraints
|
||||
indexNames := c.Tx.Catalog.ListIndexes(stmt.TableName)
|
||||
indexNames := c.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
|
||||
for _, indexName := range indexNames {
|
||||
info, err := c.Tx.Catalog.GetIndexInfo(indexName)
|
||||
info, err := c.Conn.GetTx().Catalog.GetIndexInfo(indexName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,10 +178,6 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
||||
s = s.Pipe(stream.Discard())
|
||||
}
|
||||
|
||||
st := StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
return st.Prepare(c)
|
||||
stmt.PreparedStreamStmt.Stream = s
|
||||
return stmt, nil
|
||||
}
|
||||
|
@@ -12,34 +12,19 @@ var _ Statement = (*ReIndexStmt)(nil)
|
||||
|
||||
// ReIndexStmt is a DSL that allows creating a full REINDEX statement.
|
||||
type ReIndexStmt struct {
|
||||
basePreparedStatement
|
||||
PreparedStreamStmt
|
||||
|
||||
TableOrIndexName string
|
||||
}
|
||||
|
||||
func NewReIndexStatement() *ReIndexStmt {
|
||||
var p ReIndexStmt
|
||||
|
||||
p.basePreparedStatement = basePreparedStatement{
|
||||
Preparer: &p,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
return &p
|
||||
}
|
||||
|
||||
func (stmt *ReIndexStmt) Bind(ctx *Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prepare implements the Preparer interface.
|
||||
func (stmt *ReIndexStmt) Prepare(ctx *Context) (Statement, error) {
|
||||
var indexNames []string
|
||||
|
||||
if stmt.TableOrIndexName == "" {
|
||||
indexNames = ctx.Tx.Catalog.Cache.ListObjects(database.RelationIndexType)
|
||||
} else if _, err := ctx.Tx.Catalog.GetTable(ctx.Tx, stmt.TableOrIndexName); err == nil {
|
||||
indexNames = ctx.Tx.Catalog.ListIndexes(stmt.TableOrIndexName)
|
||||
indexNames = ctx.Conn.GetTx().Catalog.Cache.ListObjects(database.RelationIndexType)
|
||||
} else if _, err := ctx.Conn.GetTx().Catalog.GetTable(ctx.Conn.GetTx(), stmt.TableOrIndexName); err == nil {
|
||||
indexNames = ctx.Conn.GetTx().Catalog.ListIndexes(stmt.TableOrIndexName)
|
||||
} else if !errs.IsNotFoundError(err) {
|
||||
return nil, err
|
||||
} else {
|
||||
@@ -49,11 +34,11 @@ func (stmt *ReIndexStmt) Prepare(ctx *Context) (Statement, error) {
|
||||
var streams []*stream.Stream
|
||||
|
||||
for _, indexName := range indexNames {
|
||||
idx, err := ctx.Tx.Catalog.GetIndex(ctx.Tx, indexName)
|
||||
idx, err := ctx.Conn.GetTx().Catalog.GetIndex(ctx.Conn.GetTx(), indexName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := ctx.Tx.Catalog.GetIndexInfo(indexName)
|
||||
info, err := ctx.Conn.GetTx().Catalog.GetIndexInfo(indexName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -67,10 +52,8 @@ func (stmt *ReIndexStmt) Prepare(ctx *Context) (Statement, error) {
|
||||
streams = append(streams, s)
|
||||
}
|
||||
|
||||
st := StreamStmt{
|
||||
Stream: stream.New(stream.Concat(streams...)).Pipe(stream.Discard()),
|
||||
ReadOnly: false,
|
||||
}
|
||||
s := stream.New(stream.Concat(streams...)).Pipe(stream.Discard())
|
||||
|
||||
return st.Prepare(ctx)
|
||||
stmt.PreparedStreamStmt.Stream = s
|
||||
return stmt, nil
|
||||
}
|
||||
|
@@ -43,13 +43,33 @@ func (stmt *SelectCoreStmt) Bind(ctx *Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) {
|
||||
isReadOnly := true
|
||||
func (stmt *SelectCoreStmt) IsReadOnly() bool {
|
||||
var isReadOnly = true
|
||||
|
||||
// SELECT is read-only most of the time, unless it's using some expressions
|
||||
// that require write access and that are allowed to be run, such as nextval
|
||||
for _, e := range stmt.ProjectionExprs {
|
||||
expr.Walk(e, func(e expr.Expr) bool {
|
||||
switch e.(type) {
|
||||
case *expr.NamedExpr:
|
||||
return true
|
||||
case *functions.NextVal:
|
||||
isReadOnly = false
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return isReadOnly
|
||||
}
|
||||
|
||||
func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*stream.Stream, error) {
|
||||
var s *stream.Stream
|
||||
|
||||
if stmt.TableName != "" {
|
||||
_, err := ctx.Tx.Catalog.GetTableInfo(stmt.TableName)
|
||||
_, err := ctx.Conn.GetTx().Catalog.GetTableInfo(stmt.TableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,35 +168,16 @@ func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) {
|
||||
}
|
||||
s = s.Pipe(rows.Project(stmt.ProjectionExprs...))
|
||||
|
||||
// SELECT is read-only most of the time, unless it's using some expressions
|
||||
// that require write access and that are allowed to be run, such as nextval
|
||||
for _, e := range stmt.ProjectionExprs {
|
||||
expr.Walk(e, func(e expr.Expr) bool {
|
||||
switch e.(type) {
|
||||
case *expr.NamedExpr:
|
||||
return true
|
||||
case *functions.NextVal:
|
||||
isReadOnly = false
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if stmt.Distinct {
|
||||
s = stream.New(stream.Union(s))
|
||||
}
|
||||
|
||||
return &StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: isReadOnly,
|
||||
}, nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SelectStmt holds SELECT configuration.
|
||||
type SelectStmt struct {
|
||||
basePreparedStatement
|
||||
PreparedStreamStmt
|
||||
|
||||
CompoundSelect []*SelectCoreStmt
|
||||
CompoundOperators []scanner.Token
|
||||
@@ -186,15 +187,13 @@ type SelectStmt struct {
|
||||
LimitExpr expr.Expr
|
||||
}
|
||||
|
||||
func NewSelectStatement() *SelectStmt {
|
||||
var p SelectStmt
|
||||
|
||||
p.basePreparedStatement = basePreparedStatement{
|
||||
Preparer: &p,
|
||||
ReadOnly: true,
|
||||
func (stmt *SelectStmt) IsReadOnly() bool {
|
||||
for i := range stmt.CompoundSelect {
|
||||
if !stmt.CompoundSelect[i].IsReadOnly() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return &p
|
||||
return true
|
||||
}
|
||||
|
||||
func (stmt *SelectStmt) Bind(ctx *Context) error {
|
||||
@@ -230,7 +229,6 @@ func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) {
|
||||
var prev scanner.Token
|
||||
|
||||
var coreStmts []*stream.Stream
|
||||
var readOnly bool = true
|
||||
|
||||
for i, coreSelect := range stmt.CompoundSelect {
|
||||
coreStmt, err := coreSelect.Prepare(ctx)
|
||||
@@ -239,16 +237,11 @@ func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) {
|
||||
}
|
||||
|
||||
if len(stmt.CompoundSelect) == 1 {
|
||||
s = coreStmt.Stream
|
||||
readOnly = coreStmt.ReadOnly
|
||||
s = coreStmt
|
||||
break
|
||||
}
|
||||
|
||||
coreStmts = append(coreStmts, coreStmt.Stream)
|
||||
|
||||
if !coreStmt.ReadOnly {
|
||||
readOnly = false
|
||||
}
|
||||
coreStmts = append(coreStmts, coreStmt)
|
||||
|
||||
var tok scanner.Token
|
||||
if i < len(stmt.CompoundOperators) {
|
||||
@@ -285,10 +278,6 @@ func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) {
|
||||
s = s.Pipe(rows.Take(stmt.LimitExpr))
|
||||
}
|
||||
|
||||
st := StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: readOnly,
|
||||
}
|
||||
|
||||
return st.Prepare(ctx)
|
||||
stmt.PreparedStreamStmt.Stream = s
|
||||
return stmt, nil
|
||||
}
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package statement
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/chaisql/chai/internal/database"
|
||||
"github.com/chaisql/chai/internal/environment"
|
||||
"github.com/chaisql/chai/internal/expr"
|
||||
@@ -9,33 +11,31 @@ import (
|
||||
|
||||
// A Statement represents a unique action that can be executed against the database.
|
||||
type Statement interface {
|
||||
Bind(*Context) error
|
||||
Run(*Context) (Result, error)
|
||||
Run(*Context) (*Result, error)
|
||||
}
|
||||
|
||||
// Optional interface that allows a statement to specify if it is read-only.
|
||||
// Defaults to false if not implemented.
|
||||
type ReadOnly interface {
|
||||
IsReadOnly() bool
|
||||
}
|
||||
|
||||
type basePreparedStatement struct {
|
||||
Preparer Preparer
|
||||
ReadOnly bool
|
||||
// Optional interface that allows a statement to specify if they need a transaction.
|
||||
// Defaults to true if not implemented.
|
||||
// If true, the engine will auto-commit.
|
||||
type Transactional interface {
|
||||
NeedsTransaction() bool
|
||||
}
|
||||
|
||||
func (stmt *basePreparedStatement) IsReadOnly() bool {
|
||||
return stmt.ReadOnly
|
||||
}
|
||||
|
||||
func (stmt *basePreparedStatement) Run(ctx *Context) (Result, error) {
|
||||
s, err := stmt.Preparer.Prepare(ctx)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
return s.Run(ctx)
|
||||
// Optional interface that allows a statement to specify if they need to be bound to database
|
||||
// objects.
|
||||
type Bindable interface {
|
||||
Bind(*Context) error
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
DB *database.Database
|
||||
Conn *database.Connection
|
||||
Tx *database.Transaction
|
||||
Params []environment.Param
|
||||
}
|
||||
|
||||
@@ -90,7 +90,10 @@ func (r *Result) Iterate(fn func(r database.Row) error) (err error) {
|
||||
// Skip iterates over the result and skips all rows.
|
||||
// It is useful when you need the query to be executed
|
||||
// but don't care about the results.
|
||||
func (r *Result) Skip() (err error) {
|
||||
func (r *Result) Skip(ctx context.Context) (err error) {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
if r.Result == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -107,6 +110,9 @@ func (r *Result) Skip() (err error) {
|
||||
defer it.Close()
|
||||
|
||||
for it.Next() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return it.Error()
|
||||
@@ -122,7 +128,7 @@ func (r *Result) Columns() ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
env := environment.New(stmt.Context.DB, stmt.Context.Tx, stmt.Context.Params, nil)
|
||||
env := environment.New(stmt.Context.DB, stmt.Context.Conn.GetTx(), stmt.Context.Params, nil)
|
||||
return stmt.Stream.Columns(env)
|
||||
}
|
||||
|
||||
@@ -158,7 +164,7 @@ func BindExpr(ctx *Context, tableName string, e expr.Expr) (err error) {
|
||||
|
||||
var info *database.TableInfo
|
||||
if tableName != "" {
|
||||
info, err = ctx.Tx.Catalog.GetTableInfo(tableName)
|
||||
info, err = ctx.Conn.GetTx().Catalog.GetTableInfo(tableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -15,33 +15,20 @@ type StreamStmt struct {
|
||||
ReadOnly bool
|
||||
}
|
||||
|
||||
// Prepare implements the Preparer interface.
|
||||
func (s *StreamStmt) Prepare(ctx *Context) (Statement, error) {
|
||||
return &PreparedStreamStmt{
|
||||
Stream: s.Stream,
|
||||
ReadOnly: s.ReadOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PreparedStreamStmt is a PreparedStreamStmt using a Stream.
|
||||
type PreparedStreamStmt struct {
|
||||
Stream *stream.Stream
|
||||
ReadOnly bool
|
||||
}
|
||||
|
||||
func (s *PreparedStreamStmt) Bind(ctx *Context) error {
|
||||
return nil
|
||||
Stream *stream.Stream
|
||||
}
|
||||
|
||||
// Run returns a result containing the stream. The stream will be executed by calling the Iterate method of
|
||||
// the result.
|
||||
func (s *PreparedStreamStmt) Run(ctx *Context) (Result, error) {
|
||||
st, err := planner.Optimize(s.Stream.Clone(), ctx.Tx.Catalog, ctx.Params)
|
||||
func (s *PreparedStreamStmt) Run(ctx *Context) (*Result, error) {
|
||||
st, err := planner.Optimize(s.Stream.Clone(), ctx.Conn.GetTx().Catalog, ctx.Params)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Result{
|
||||
return &Result{
|
||||
Result: &StreamStmtResult{
|
||||
Stream: st,
|
||||
Context: ctx,
|
||||
@@ -49,11 +36,6 @@ func (s *PreparedStreamStmt) Run(ctx *Context) (Result, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsReadOnly reports whether the stream will modify the database or only read it.
|
||||
func (s *PreparedStreamStmt) IsReadOnly() bool {
|
||||
return s.ReadOnly
|
||||
}
|
||||
|
||||
func (s *PreparedStreamStmt) String() string {
|
||||
return s.Stream.String()
|
||||
}
|
||||
@@ -65,7 +47,7 @@ type StreamStmtResult struct {
|
||||
}
|
||||
|
||||
func (s *StreamStmtResult) Iterator() (database.Iterator, error) {
|
||||
env := environment.New(s.Context.DB, s.Context.Tx, s.Context.Params, nil)
|
||||
env := environment.New(s.Context.DB, s.Context.Conn.GetTx(), s.Context.Params, nil)
|
||||
|
||||
return s.Stream.Iterator(env)
|
||||
}
|
||||
|
@@ -13,7 +13,7 @@ var _ Statement = (*UpdateStmt)(nil)
|
||||
|
||||
// UpdateConfig holds UPDATE configuration.
|
||||
type UpdateStmt struct {
|
||||
basePreparedStatement
|
||||
PreparedStreamStmt
|
||||
|
||||
TableName string
|
||||
|
||||
@@ -25,17 +25,6 @@ type UpdateStmt struct {
|
||||
WhereExpr expr.Expr
|
||||
}
|
||||
|
||||
func NewUpdateStatement() *UpdateStmt {
|
||||
var p UpdateStmt
|
||||
|
||||
p.basePreparedStatement = basePreparedStatement{
|
||||
Preparer: &p,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
return &p
|
||||
}
|
||||
|
||||
type UpdateSetPair struct {
|
||||
Column *expr.Column
|
||||
E expr.Expr
|
||||
@@ -64,7 +53,7 @@ func (stmt *UpdateStmt) Bind(ctx *Context) error {
|
||||
|
||||
// Prepare implements the Preparer interface.
|
||||
func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
|
||||
ti, err := c.Tx.Catalog.GetTableInfo(stmt.TableName)
|
||||
ti, err := c.Conn.GetTx().Catalog.GetTableInfo(stmt.TableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -99,7 +88,7 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
|
||||
// TODO(asdine): This removes ALL indexed fields for each row
|
||||
// even if the update modified a single field. We should only
|
||||
// update the indexed fields that were modified.
|
||||
indexNames := c.Tx.Catalog.ListIndexes(stmt.TableName)
|
||||
indexNames := c.Conn.GetTx().Catalog.ListIndexes(stmt.TableName)
|
||||
for _, indexName := range indexNames {
|
||||
s = s.Pipe(index.Delete(indexName))
|
||||
}
|
||||
@@ -114,7 +103,7 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
|
||||
}
|
||||
|
||||
for _, indexName := range indexNames {
|
||||
info, err := c.Tx.Catalog.GetIndexInfo(indexName)
|
||||
info, err := c.Conn.GetTx().Catalog.GetIndexInfo(indexName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -127,10 +116,6 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
|
||||
|
||||
s = s.Pipe(stream.Discard())
|
||||
|
||||
st := StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
return st.Prepare(c)
|
||||
stmt.PreparedStreamStmt.Stream = s
|
||||
return stmt, nil
|
||||
}
|
||||
|
@@ -6,109 +6,58 @@ import (
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
var _ queryAlterer = BeginStmt{}
|
||||
var _ queryAlterer = RollbackStmt{}
|
||||
var _ queryAlterer = CommitStmt{}
|
||||
|
||||
// BeginStmt is a statement that creates a new transaction.
|
||||
type BeginStmt struct {
|
||||
Writable bool
|
||||
}
|
||||
|
||||
func (stmt BeginStmt) Bind(ctx *statement.Context) error {
|
||||
return nil
|
||||
func (stmt BeginStmt) NeedsTransaction() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Prepare implements the Preparer interface.
|
||||
func (stmt BeginStmt) Prepare(*statement.Context) (statement.Statement, error) {
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt BeginStmt) alterQuery(conn *database.Connection, q *Query) error {
|
||||
if q.tx != nil {
|
||||
return errors.New("cannot begin a transaction within a transaction")
|
||||
func (stmt BeginStmt) Run(ctx *statement.Context) (*statement.Result, error) {
|
||||
if ctx.Conn.GetTx() != nil {
|
||||
return nil, errors.New("cannot begin a transaction within a transaction")
|
||||
}
|
||||
|
||||
var err error
|
||||
q.tx, err = conn.BeginTx(&database.TxOptions{
|
||||
_, err := ctx.Conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: !stmt.Writable,
|
||||
})
|
||||
q.autoCommit = false
|
||||
return err
|
||||
}
|
||||
|
||||
func (stmt BeginStmt) IsReadOnly() bool {
|
||||
return !stmt.Writable
|
||||
}
|
||||
|
||||
func (stmt BeginStmt) Run(ctx *statement.Context) (statement.Result, error) {
|
||||
return statement.Result{}, errors.New("cannot begin a transaction within a transaction")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// RollbackStmt is a statement that rollbacks the current active transaction.
|
||||
type RollbackStmt struct{}
|
||||
|
||||
func (stmt RollbackStmt) Bind(ctx *statement.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prepare implements the Preparer interface.
|
||||
func (stmt RollbackStmt) Prepare(*statement.Context) (statement.Statement, error) {
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt RollbackStmt) alterQuery(conn *database.Connection, q *Query) error {
|
||||
if q.tx == nil || q.autoCommit {
|
||||
return errors.New("cannot rollback with no active transaction")
|
||||
}
|
||||
|
||||
err := q.tx.Rollback()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q.tx = nil
|
||||
q.autoCommit = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stmt RollbackStmt) IsReadOnly() bool {
|
||||
func (stmt RollbackStmt) NeedsTransaction() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt RollbackStmt) Run(ctx *statement.Context) (statement.Result, error) {
|
||||
return statement.Result{}, errors.New("cannot rollback with no active transaction")
|
||||
func (stmt RollbackStmt) Run(ctx *statement.Context) (*statement.Result, error) {
|
||||
tx := ctx.Conn.GetTx()
|
||||
if tx == nil {
|
||||
return nil, errors.New("cannot rollback with no active transaction")
|
||||
}
|
||||
|
||||
return nil, tx.Rollback()
|
||||
}
|
||||
|
||||
// CommitStmt is a statement that commits the current active transaction.
|
||||
type CommitStmt struct{}
|
||||
|
||||
func (stmt CommitStmt) Bind(ctx *statement.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prepare implements the Preparer interface.
|
||||
func (stmt CommitStmt) Prepare(*statement.Context) (statement.Statement, error) {
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt CommitStmt) alterQuery(conn *database.Connection, q *Query) error {
|
||||
if q.tx == nil || q.autoCommit {
|
||||
return errors.New("cannot commit with no active transaction")
|
||||
}
|
||||
|
||||
err := q.tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q.tx = nil
|
||||
q.autoCommit = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stmt CommitStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt CommitStmt) Run(ctx *statement.Context) (statement.Result, error) {
|
||||
return statement.Result{}, errors.New("cannot commit with no active transaction")
|
||||
func (stmt CommitStmt) NeedsTransaction() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt CommitStmt) Run(ctx *statement.Context) (*statement.Result, error) {
|
||||
tx := ctx.Conn.GetTx()
|
||||
if tx == nil {
|
||||
return nil, errors.New("cannot commit with no active transaction")
|
||||
}
|
||||
|
||||
return nil, tx.Commit()
|
||||
}
|
||||
|
@@ -17,8 +17,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
_ driver.Driver = (*Driver)(nil)
|
||||
_ driver.DriverContext = (*Driver)(nil)
|
||||
_ driver.Driver = (*Driver)(nil)
|
||||
_ driver.DriverContext = (*Driver)(nil)
|
||||
_ driver.QueryerContext = (*Conn)(nil)
|
||||
_ driver.ExecerContext = (*Conn)(nil)
|
||||
)
|
||||
|
||||
// Driver is a driver.Driver that can open a new connection to a Chai database.
|
||||
@@ -98,24 +100,100 @@ func (c *Conn) Prepare(q string) (driver.Stmt, error) {
|
||||
return c.PrepareContext(context.Background(), q)
|
||||
}
|
||||
|
||||
// PrepareContext returns a prepared statement, bound to this connection.
|
||||
func (c *Conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
|
||||
pq, err := parser.ParseQuery(q)
|
||||
func (c *Conn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
statements, err := parser.ParseQuery(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = pq.Prepare(&query.Context{
|
||||
Ctx: ctx,
|
||||
DB: c.db,
|
||||
Conn: c.conn,
|
||||
res, err := query.New(statements...).Run(&query.Context{
|
||||
Ctx: ctx,
|
||||
DB: c.DB(),
|
||||
Conn: c.conn,
|
||||
Params: NamedValueToParams(args),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
er := res.Close()
|
||||
if err == nil {
|
||||
err = er
|
||||
}
|
||||
}()
|
||||
|
||||
return ExecResult{}, res.Skip(ctx)
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
statements, err := parser.ParseQuery(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := query.New(statements...).Run(&query.Context{
|
||||
Ctx: ctx,
|
||||
DB: c.DB(),
|
||||
Conn: c.conn,
|
||||
Params: NamedValueToParams(args),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewRows(res)
|
||||
}
|
||||
|
||||
// PrepareContext returns a prepared statement, bound to this connection.
|
||||
func (c *Conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
|
||||
statements, err := parser.ParseQuery(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(statements) != 1 {
|
||||
return nil, errors.New("cannot insert multiple commands into a prepared statement")
|
||||
}
|
||||
|
||||
sctx := statement.Context{
|
||||
DB: c.db,
|
||||
Conn: c.conn,
|
||||
}
|
||||
|
||||
tx, err := c.conn.BeginTx(&database.TxOptions{
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt := statements[0]
|
||||
|
||||
if b, ok := stmt.(statement.Bindable); ok {
|
||||
err = b.Bind(&sctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if p, ok := stmt.(statement.Preparer); ok {
|
||||
stmt, err = p.Prepare(&sctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Stmt{
|
||||
pq: pq,
|
||||
stmt: stmt,
|
||||
conn: c,
|
||||
}, nil
|
||||
}
|
||||
@@ -162,7 +240,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
// Stmt is a prepared statement. It is bound to a Conn and not
|
||||
// used by multiple goroutines concurrently.
|
||||
type Stmt struct {
|
||||
pq query.Query
|
||||
stmt statement.Statement
|
||||
conn *Conn
|
||||
}
|
||||
|
||||
@@ -178,14 +256,11 @@ func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
// ExecContext executes a query that doesn't return rows, such
|
||||
// as an INSERT or UPDATE.
|
||||
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.pq.Run(&query.Context{
|
||||
Ctx: ctx,
|
||||
res, err := s.stmt.Run(&statement.Context{
|
||||
DB: s.conn.db,
|
||||
Conn: s.conn.conn,
|
||||
Params: NamedValueToParams(args),
|
||||
@@ -200,7 +275,26 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
}
|
||||
}()
|
||||
|
||||
return ExecResult{}, res.Skip()
|
||||
return ExecResult{}, res.Skip(ctx)
|
||||
}
|
||||
|
||||
// QueryContext executes a query that may return rows, such as a
|
||||
// SELECT.
|
||||
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.stmt.Run(&statement.Context{
|
||||
DB: s.conn.db,
|
||||
Conn: s.conn.conn,
|
||||
Params: NamedValueToParams(args),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewRows(res)
|
||||
}
|
||||
|
||||
type ExecResult struct{}
|
||||
@@ -219,28 +313,6 @@ func (s Stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// QueryContext executes a query that may return rows, such as a
|
||||
// SELECT.
|
||||
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
res, err := s.pq.Run(&query.Context{
|
||||
Ctx: ctx,
|
||||
DB: s.conn.db,
|
||||
Conn: s.conn.conn,
|
||||
Params: NamedValueToParams(args),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewRows(res)
|
||||
}
|
||||
|
||||
// Close does nothing.
|
||||
func (s Stmt) Close() error {
|
||||
return nil
|
||||
@@ -253,6 +325,10 @@ type Rows struct {
|
||||
}
|
||||
|
||||
func NewRows(res *statement.Result) (*Rows, error) {
|
||||
if res == nil {
|
||||
return &Rows{}, nil
|
||||
}
|
||||
|
||||
columns, err := res.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -26,14 +26,14 @@ func TestParserAlterTable(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, test.expected, q.Statements[0])
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected, stmts[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -89,14 +89,14 @@ func TestParserAlterTableAddColumn(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, test.expected, q.Statements[0])
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected, stmts[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -49,14 +49,14 @@ func TestParserCreateIndex(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, test.expected, q.Statements[0])
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected, stmts[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -190,14 +190,14 @@ func TestParserCreateSequence(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, test.expected, q.Statements[0])
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected, stmts[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
// parseDeleteStatement parses a delete string and returns a Statement AST row.
|
||||
func (p *Parser) parseDeleteStatement() (statement.Statement, error) {
|
||||
stmt := statement.NewDeleteStatement()
|
||||
var stmt statement.DeleteStmt
|
||||
var err error
|
||||
|
||||
// Parse "DELETE FROM".
|
||||
@@ -49,5 +49,5 @@ func (p *Parser) parseDeleteStatement() (statement.Statement, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
return &stmt, nil
|
||||
}
|
||||
|
@@ -1,11 +1,9 @@
|
||||
package parser_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/chaisql/chai/internal/expr"
|
||||
"github.com/chaisql/chai/internal/query"
|
||||
"github.com/chaisql/chai/internal/query/statement"
|
||||
"github.com/chaisql/chai/internal/sql/parser"
|
||||
"github.com/chaisql/chai/internal/stream"
|
||||
@@ -23,7 +21,7 @@ func TestParserDelete(t *testing.T) {
|
||||
|
||||
parseExpr := func(s string) expr.Expr {
|
||||
e := parser.MustParseExpr(s)
|
||||
err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, "test", e)
|
||||
err := statement.BindExpr(&statement.Context{DB: db, Conn: tx.Connection()}, "test", e)
|
||||
require.NoError(t, err)
|
||||
return e
|
||||
}
|
||||
@@ -75,18 +73,17 @@ func TestParserDelete(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = q.Prepare(&query.Context{
|
||||
Ctx: context.Background(),
|
||||
require.Len(t, stmts, 1)
|
||||
stmt, err := stmts[0].(statement.Preparer).Prepare(&statement.Context{
|
||||
DB: db,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, &statement.PreparedStreamStmt{Stream: test.expected}, q.Statements[0].(*statement.PreparedStreamStmt))
|
||||
require.Equal(t, test.expected.String(), stmt.(*statement.DeleteStmt).Stream.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -25,14 +25,14 @@ func TestParserDrop(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, test.expected, q.Statements[0])
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected, stmts[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TestParserExplain(t *testing.T) {
|
||||
slct := statement.NewSelectStatement()
|
||||
var slct statement.SelectStmt
|
||||
slct.CompoundSelect = []*statement.SelectCoreStmt{
|
||||
{TableName: "test", ProjectionExprs: []expr.Expr{expr.Wildcard{}}},
|
||||
}
|
||||
@@ -21,20 +21,20 @@ func TestParserExplain(t *testing.T) {
|
||||
expected statement.Statement
|
||||
errored bool
|
||||
}{
|
||||
{"Explain select", "EXPLAIN SELECT * FROM test", &statement.ExplainStmt{Statement: slct}, false},
|
||||
{"Explain select", "EXPLAIN SELECT * FROM test", &statement.ExplainStmt{Statement: &slct}, false},
|
||||
{"Multiple Explains", "EXPLAIN EXPLAIN CREATE TABLE test", nil, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, test.expected, q.Statements[0])
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected, stmts[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
// parseInsertStatement parses an insert string and returns a Statement AST row.
|
||||
func (p *Parser) parseInsertStatement() (*statement.InsertStmt, error) {
|
||||
stmt := statement.NewInsertStatement()
|
||||
var stmt statement.InsertStmt
|
||||
var err error
|
||||
|
||||
// Parse "INSERT INTO".
|
||||
@@ -64,7 +64,7 @@ func (p *Parser) parseInsertStatement() (*statement.InsertStmt, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
return &stmt, nil
|
||||
}
|
||||
|
||||
// parseColumnList parses a list of columns in the form: (column, column, ...), if exists.
|
||||
|
@@ -1,11 +1,9 @@
|
||||
package parser_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/chaisql/chai/internal/expr"
|
||||
"github.com/chaisql/chai/internal/query"
|
||||
"github.com/chaisql/chai/internal/query/statement"
|
||||
"github.com/chaisql/chai/internal/sql/parser"
|
||||
"github.com/chaisql/chai/internal/stream"
|
||||
@@ -221,23 +219,22 @@ func TestParserInsert(t *testing.T) {
|
||||
|
||||
testutil.MustExec(t, db, tx, "CREATE TABLE test(a TEXT, b TEXT); CREATE TABLE foo(c TEXT, d TEXT);")
|
||||
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.fails {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
err = q.Prepare(&query.Context{
|
||||
Ctx: context.Background(),
|
||||
require.Len(t, stmts, 1)
|
||||
|
||||
stmt, err := stmts[0].(statement.Preparer).Prepare(&statement.Context{
|
||||
DB: db,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, q.Statements, 1)
|
||||
|
||||
require.Equal(t, test.expected.String(), q.Statements[0].(*statement.PreparedStreamStmt).Stream.String())
|
||||
require.Equal(t, test.expected.String(), stmt.(*statement.InsertStmt).Stream.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -6,7 +6,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/chaisql/chai/internal/expr"
|
||||
"github.com/chaisql/chai/internal/query"
|
||||
"github.com/chaisql/chai/internal/query/statement"
|
||||
"github.com/chaisql/chai/internal/sql/scanner"
|
||||
"github.com/chaisql/chai/internal/tree"
|
||||
@@ -24,7 +23,7 @@ func NewParser(r io.Reader) *Parser {
|
||||
}
|
||||
|
||||
// ParseQuery parses a query string and returns its AST representation.
|
||||
func ParseQuery(s string) (query.Query, error) {
|
||||
func ParseQuery(s string) ([]statement.Statement, error) {
|
||||
return NewParser(strings.NewReader(s)).ParseQuery()
|
||||
}
|
||||
|
||||
@@ -45,7 +44,7 @@ func MustParseExpr(s string) expr.Expr {
|
||||
}
|
||||
|
||||
// ParseQuery parses a Chai SQL string and returns a Query.
|
||||
func (p *Parser) ParseQuery() (query.Query, error) {
|
||||
func (p *Parser) ParseQuery() ([]statement.Statement, error) {
|
||||
var statements []statement.Statement
|
||||
|
||||
err := p.Parse(func(s statement.Statement) error {
|
||||
@@ -53,10 +52,10 @@ func (p *Parser) ParseQuery() (query.Query, error) {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return query.Query{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return query.Query{Statements: statements}, nil
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
// ParseQuery parses a Chai SQL string and returns a Query.
|
||||
|
@@ -10,8 +10,8 @@ import (
|
||||
func FuzzParseQuery(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, s string) {
|
||||
// Fuzz ParseQuery for panics.
|
||||
q, err := ParseQuery(s)
|
||||
if err != nil || len(q.Statements) < 1 {
|
||||
statements, err := ParseQuery(s)
|
||||
if err != nil || len(statements) < 1 {
|
||||
t.Skip()
|
||||
}
|
||||
})
|
||||
|
@@ -11,12 +11,12 @@ import (
|
||||
)
|
||||
|
||||
func TestParserMultiStatement(t *testing.T) {
|
||||
slct := statement.NewSelectStatement()
|
||||
var slct statement.SelectStmt
|
||||
slct.CompoundSelect = []*statement.SelectCoreStmt{
|
||||
{TableName: "foo", ProjectionExprs: []expr.Expr{expr.Wildcard{}}},
|
||||
}
|
||||
|
||||
dlt := statement.NewDeleteStatement()
|
||||
var dlt statement.DeleteStmt
|
||||
dlt.TableName = "foo"
|
||||
|
||||
tests := []struct {
|
||||
@@ -26,16 +26,16 @@ func TestParserMultiStatement(t *testing.T) {
|
||||
}{
|
||||
{"OnlyCommas", ";;;", nil},
|
||||
{"TrailingComma", "SELECT * FROM foo;;;DELETE FROM foo;", []statement.Statement{
|
||||
slct,
|
||||
dlt,
|
||||
&slct,
|
||||
&dlt,
|
||||
}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, test.expected, q.Statements)
|
||||
require.EqualValues(t, test.expected, stmts)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
// parseReindexStatement parses a reindex statement.
|
||||
func (p *Parser) parseReIndexStatement() (statement.Statement, error) {
|
||||
stmt := statement.NewReIndexStatement()
|
||||
var stmt statement.ReIndexStmt
|
||||
|
||||
// Parse "REINDEX".
|
||||
if err := p.ParseTokens(scanner.REINDEX); err != nil {
|
||||
@@ -20,5 +20,5 @@ func (p *Parser) parseReIndexStatement() (statement.Statement, error) {
|
||||
} else {
|
||||
p.Unscan()
|
||||
}
|
||||
return stmt, nil
|
||||
return &stmt, nil
|
||||
}
|
||||
|
@@ -9,8 +9,8 @@ import (
|
||||
)
|
||||
|
||||
func TestParserReIndex(t *testing.T) {
|
||||
r1 := statement.NewReIndexStatement()
|
||||
r2 := statement.NewReIndexStatement()
|
||||
var r1 statement.ReIndexStmt
|
||||
var r2 statement.ReIndexStmt
|
||||
r2.TableOrIndexName = "tableOrIndex"
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -18,21 +18,21 @@ func TestParserReIndex(t *testing.T) {
|
||||
expected statement.Statement
|
||||
errored bool
|
||||
}{
|
||||
{"All", "REINDEX", r1, false},
|
||||
{"With ident", "REINDEX tableOrIndex", r2, false},
|
||||
{"All", "REINDEX", &r1, false},
|
||||
{"With ident", "REINDEX tableOrIndex", &r2, false},
|
||||
{"With extra", "REINDEX tableOrIndex tableOrIndex", nil, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, test.expected, q.Statements[0])
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected, stmts[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -10,10 +10,10 @@ import (
|
||||
// parseSelectStatement parses a select string and returns a Statement AST row.
|
||||
// This function assumes the SELECT token has already been consumed.
|
||||
func (p *Parser) parseSelectStatement() (*statement.SelectStmt, error) {
|
||||
stmt := statement.NewSelectStatement()
|
||||
var stmt statement.SelectStmt
|
||||
|
||||
// Parse SELECT ... [UNION | UNION ALL | INTERSECT] SELECT ...
|
||||
err := p.parseCompoundSelectStatement(stmt)
|
||||
err := p.parseCompoundSelectStatement(&stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func (p *Parser) parseSelectStatement() (*statement.SelectStmt, error) {
|
||||
return nil, errors.Wrap(err, "failed to parse OFFSET clause")
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
return &stmt, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseCompoundSelectStatement(stmt *statement.SelectStmt) error {
|
||||
|
@@ -1,12 +1,10 @@
|
||||
package parser_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/chaisql/chai/internal/expr"
|
||||
"github.com/chaisql/chai/internal/expr/functions"
|
||||
"github.com/chaisql/chai/internal/query"
|
||||
"github.com/chaisql/chai/internal/query/statement"
|
||||
"github.com/chaisql/chai/internal/sql/parser"
|
||||
"github.com/chaisql/chai/internal/stream"
|
||||
@@ -37,7 +35,7 @@ func TestParserSelect(t *testing.T) {
|
||||
if len(table) > 0 {
|
||||
tb = table[0]
|
||||
}
|
||||
err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, tb, e)
|
||||
err := statement.BindExpr(&statement.Context{DB: db, Conn: tx.Connection()}, tb, e)
|
||||
require.NoError(t, err)
|
||||
return e
|
||||
}
|
||||
@@ -391,27 +389,20 @@ func TestParserSelect(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.mustFail {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.Len(t, stmts, 1)
|
||||
|
||||
err = q.Prepare(&query.Context{
|
||||
Ctx: context.Background(),
|
||||
stmt, err := stmts[0].(statement.Preparer).Prepare(&statement.Context{
|
||||
DB: db,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, &statement.PreparedStreamStmt{ReadOnly: test.readOnly, Stream: test.expected}, q.Statements[0].(*statement.PreparedStreamStmt))
|
||||
require.Equal(t, test.expected.String(), stmt.(*statement.SelectStmt).Stream.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelect(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = parser.ParseQuery("SELECT a, b AS `foo` FROM `some table` WHERE d.e[100] >= 12 AND c.d IN ([1, true], [2, false]) GROUP BY d.e[0] LIMIT 10 + 10 OFFSET 20 - 20 ORDER BY d DESC")
|
||||
}
|
||||
}
|
||||
|
@@ -29,14 +29,14 @@ func TestParserTransactions(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.s, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, test.expected, q.Statements[0])
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected, stmts[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
// parseUpdateStatement parses a update string and returns a Statement AST row.
|
||||
func (p *Parser) parseUpdateStatement() (*statement.UpdateStmt, error) {
|
||||
stmt := statement.NewUpdateStatement()
|
||||
var stmt statement.UpdateStmt
|
||||
var err error
|
||||
|
||||
// Parse "UPDATE".
|
||||
@@ -42,7 +42,7 @@ func (p *Parser) parseUpdateStatement() (*statement.UpdateStmt, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
return &stmt, nil
|
||||
}
|
||||
|
||||
// parseSetClause parses the "SET" clause of the query.
|
||||
|
@@ -1,11 +1,9 @@
|
||||
package parser_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/chaisql/chai/internal/expr"
|
||||
"github.com/chaisql/chai/internal/query"
|
||||
"github.com/chaisql/chai/internal/query/statement"
|
||||
"github.com/chaisql/chai/internal/sql/parser"
|
||||
"github.com/chaisql/chai/internal/stream"
|
||||
@@ -28,7 +26,7 @@ func TestParserUpdate(t *testing.T) {
|
||||
if len(table) > 0 {
|
||||
tb = table[0]
|
||||
}
|
||||
err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, tb, e)
|
||||
err := statement.BindExpr(&statement.Context{DB: db, Conn: tx.Connection()}, tb, e)
|
||||
require.NoError(t, err)
|
||||
return e
|
||||
}
|
||||
@@ -66,22 +64,21 @@ func TestParserUpdate(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
stmts, err := parser.ParseQuery(test.s)
|
||||
if test.errored {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
err = q.Prepare(&query.Context{
|
||||
Ctx: context.Background(),
|
||||
stmt, err := stmts[0].(statement.Preparer).Prepare(&statement.Context{
|
||||
DB: db,
|
||||
Conn: tx.Connection(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, &statement.PreparedStreamStmt{Stream: test.expected}, q.Statements[0].(*statement.PreparedStreamStmt))
|
||||
require.Len(t, stmts, 1)
|
||||
require.EqualValues(t, test.expected.String(), stmt.(*statement.UpdateStmt).Stream.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -85,8 +85,7 @@ func NewTestTx(t testing.TB) (*database.Database, *database.Transaction, func())
|
||||
require.NoError(t, err)
|
||||
|
||||
return db, tx, func() {
|
||||
err = tx.Rollback()
|
||||
require.NoError(t, err)
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,7 +96,7 @@ func Exec(db *database.Database, tx *database.Transaction, q string, params ...e
|
||||
}
|
||||
defer res.Close()
|
||||
|
||||
return res.Skip()
|
||||
return res.Skip(context.Background())
|
||||
}
|
||||
|
||||
func Query(db *database.Database, tx *database.Transaction, q string, params ...environment.Param) (*statement.Result, error) {
|
||||
@@ -107,12 +106,8 @@ func Query(db *database.Database, tx *database.Transaction, q string, params ...
|
||||
}
|
||||
|
||||
ctx := &query.Context{Ctx: context.Background(), DB: db, Conn: tx.Connection(), Params: params}
|
||||
err = pq.Prepare(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pq.Run(ctx)
|
||||
return query.New(pq...).Run(ctx)
|
||||
}
|
||||
|
||||
func MustExec(t *testing.T, db *database.Database, tx *database.Transaction, q string, params ...environment.Param) {
|
||||
|
Reference in New Issue
Block a user