Move stream building to query package

This commit is contained in:
Asdine El Hrychy
2021-05-23 23:50:18 +04:00
parent 6fff623cd1
commit 1a3d4f57a9
17 changed files with 526 additions and 429 deletions

93
internal/query/delete.go Normal file
View File

@@ -0,0 +1,93 @@
package query
import (
"github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/sql/scanner"
"github.com/genjidb/genji/internal/stream"
"github.com/genjidb/genji/internal/stringutil"
)
// DeleteConfig holds DELETE configuration.
type DeleteStmt struct {
TableName string
WhereExpr expr.Expr
OffsetExpr expr.Expr
OrderBy expr.Path
LimitExpr expr.Expr
OrderByDirection scanner.Token
}
func (stmt *DeleteStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
var res Result
s, err := stmt.ToStream()
if err != nil {
return res, err
}
return s.Run(tx, params)
}
func (stmt *DeleteStmt) IsReadOnly() bool {
return false
}
func (stmt *DeleteStmt) ToStream() (*StreamStmt, error) {
s := stream.New(stream.SeqScan(stmt.TableName))
if stmt.WhereExpr != nil {
s = s.Pipe(stream.Filter(stmt.WhereExpr))
}
if stmt.OrderBy != nil {
if stmt.OrderByDirection == scanner.DESC {
s = s.Pipe(stream.SortReverse(stmt.OrderBy))
} else {
s = s.Pipe(stream.Sort(stmt.OrderBy))
}
}
if stmt.OffsetExpr != nil {
v, err := stmt.OffsetExpr.Eval(&expr.Environment{})
if err != nil {
return nil, err
}
if !v.Type.IsNumber() {
return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type)
}
v, err = v.CastAsInteger()
if err != nil {
return nil, err
}
s = s.Pipe(stream.Skip(v.V.(int64)))
}
if stmt.LimitExpr != nil {
v, err := stmt.LimitExpr.Eval(&expr.Environment{})
if err != nil {
return nil, err
}
if !v.Type.IsNumber() {
return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type)
}
v, err = v.CastAsInteger()
if err != nil {
return nil, err
}
s = s.Pipe(stream.Take(v.V.(int64)))
}
s = s.Pipe(stream.TableDelete(stmt.TableName))
return &StreamStmt{
Stream: s,
ReadOnly: false,
}, nil
}

View File

@@ -21,35 +21,50 @@ type ExplainStmt struct {
// If the statement is a stream, Optimize will be called prior to
// displaying all the operations.
// Explain currently only works on SELECT, UPDATE, INSERT and DELETE statements.
func (s *ExplainStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
switch t := s.Statement.(type) {
case *StreamStmt:
s, err := planner.Optimize(t.Stream, tx, params)
if err != nil {
return Result{}, err
}
func (stmt *ExplainStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
var ss *StreamStmt
var err error
var res Result
var plan string
if s != nil {
plan = s.String()
} else {
plan = "<no exec>"
}
newStatement := StreamStmt{
Stream: &stream.Stream{
Op: stream.Project(
&expr.NamedExpr{
ExprName: "plan",
Expr: expr.LiteralValue(document.NewTextValue(plan)),
}),
},
ReadOnly: true,
}
return newStatement.Run(tx, params)
switch t := stmt.Statement.(type) {
case *SelectStmt:
ss, err = t.ToStream()
case *UpdateStmt:
ss = t.ToStream()
case *InsertStmt:
ss, err = t.ToStream()
case *DeleteStmt:
ss, err = t.ToStream()
default:
return Result{}, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements")
}
if err != nil {
return res, err
}
return Result{}, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements")
s, err := planner.Optimize(ss.Stream, tx, params)
if err != nil {
return Result{}, err
}
var plan string
if s != nil {
plan = s.String()
} else {
plan = "<no exec>"
}
newStatement := StreamStmt{
Stream: &stream.Stream{
Op: stream.Project(
&expr.NamedExpr{
ExprName: "plan",
Expr: expr.LiteralValue(document.NewTextValue(plan)),
}),
},
ReadOnly: true,
}
return newStatement.Run(tx, params)
}
// IsReadOnly indicates that this statement doesn't write anything into

69
internal/query/insert.go Normal file
View File

@@ -0,0 +1,69 @@
package query
import (
"errors"
"github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/stream"
)
// InsertStmt holds INSERT configuration.
type InsertStmt struct {
TableName string
Values []expr.Expr
Fields []string
SelectStmt *SelectStmt
Returning []expr.Expr
}
func (stmt *InsertStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
var res Result
s, err := stmt.ToStream()
if err != nil {
return res, err
}
return s.Run(tx, params)
}
func (stmt *InsertStmt) IsReadOnly() bool {
return false
}
func (stmt *InsertStmt) ToStream() (*StreamStmt, error) {
var s *stream.Stream
if stmt.Values != nil {
s = stream.New(stream.Expressions(stmt.Values...))
s = s.Pipe(stream.TableInsert(stmt.TableName))
} else {
st, err := stmt.SelectStmt.ToStream()
if err != nil {
return nil, err
}
s = st.Stream
// ensure we are not reading and writing to the same table.
if s.First().(*stream.SeqScanOperator).TableName == stmt.TableName {
return nil, errors.New("cannot read and write to the same table")
}
if len(stmt.Fields) > 0 {
s = s.Pipe(stream.IterRename(stmt.Fields...))
}
s = s.Pipe(stream.TableInsert(stmt.TableName))
}
if len(stmt.Returning) > 0 {
s = s.Pipe(stream.Project(stmt.Returning...))
}
return &StreamStmt{
Stream: s,
ReadOnly: false,
}, nil
}

View File

@@ -23,6 +23,7 @@ func TestInsertStmt(t *testing.T) {
{"Values / Invalid params", "INSERT INTO test (a, b, c) VALUES ('d', ?)", true, "", []interface{}{'e'}},
{"Documents / Named Params", "INSERT INTO test VALUES {a: $a, b: 2.3, c: $c}", false, `[{"pk()":1,"a":1,"b":2.3,"c":true}]`, []interface{}{sql.Named("c", true), sql.Named("a", 1)}},
{"Documents / List ", "INSERT INTO test VALUES {a: [1, 2, 3]}", false, `[{"pk()":1,"a":[1,2,3]}]`, nil},
{"Select / same table", "INSERT INTO test SELECT * FROM test", true, ``, nil},
}
for _, test := range tests {

188
internal/query/select.go Normal file
View File

@@ -0,0 +1,188 @@
package query
import (
"errors"
"github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/sql/scanner"
"github.com/genjidb/genji/internal/stream"
"github.com/genjidb/genji/internal/stringutil"
)
// SelectStmt holds SELECT configuration.
type SelectStmt struct {
TableName string
Distinct bool
WhereExpr expr.Expr
GroupByExpr expr.Expr
OrderBy expr.Path
OrderByDirection scanner.Token
OffsetExpr expr.Expr
LimitExpr expr.Expr
ProjectionExprs []expr.Expr
}
func (stmt *SelectStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
var res Result
s, err := stmt.ToStream()
if err != nil {
return res, err
}
return s.Run(tx, params)
}
func (stmt *SelectStmt) IsReadOnly() bool {
return true
}
func (stmt *SelectStmt) ToStream() (*StreamStmt, error) {
var s *stream.Stream
if stmt.TableName != "" {
s = stream.New(stream.SeqScan(stmt.TableName))
}
if stmt.WhereExpr != nil {
s = s.Pipe(stream.Filter(stmt.WhereExpr))
}
// when using GROUP BY, only aggregation functions or GroupByExpr can be selected
if stmt.GroupByExpr != nil {
// add Group node
s = s.Pipe(stream.GroupBy(stmt.GroupByExpr))
var invalidProjectedField expr.Expr
var aggregators []expr.AggregatorBuilder
for _, pe := range stmt.ProjectionExprs {
ne, ok := pe.(*expr.NamedExpr)
if !ok {
invalidProjectedField = pe
break
}
e := ne.Expr
// check if the projected expression is an aggregation function
if agg, ok := e.(expr.AggregatorBuilder); ok {
aggregators = append(aggregators, agg)
continue
}
// check if this is the same expression as the one used in the GROUP BY clause
if expr.Equal(e, stmt.GroupByExpr) {
continue
}
// otherwise it's an error
invalidProjectedField = ne
break
}
if invalidProjectedField != nil {
return nil, stringutil.Errorf("field %q must appear in the GROUP BY clause or be used in an aggregate function", invalidProjectedField)
}
// add Aggregation node
s = s.Pipe(stream.HashAggregate(aggregators...))
} else {
// if there is no GROUP BY clause, check if there are any aggregation function
// and if so add an aggregation node
var aggregators []expr.AggregatorBuilder
for _, pe := range stmt.ProjectionExprs {
ne, ok := pe.(*expr.NamedExpr)
if !ok {
continue
}
e := ne.Expr
// check if the projected expression is an aggregation function
if agg, ok := e.(expr.AggregatorBuilder); ok {
aggregators = append(aggregators, agg)
}
}
// add Aggregation node
if len(aggregators) > 0 {
s = s.Pipe(stream.HashAggregate(aggregators...))
}
}
// If there is no FROM clause ensure there is no wildcard or path
if stmt.TableName == "" {
var err error
for _, e := range stmt.ProjectionExprs {
expr.Walk(e, func(e expr.Expr) bool {
switch e.(type) {
case expr.Path, expr.Wildcard:
err = errors.New("no tables specified")
return false
default:
return true
}
})
if err != nil {
return nil, err
}
}
}
s = s.Pipe(stream.Project(stmt.ProjectionExprs...))
if stmt.Distinct {
s = s.Pipe(stream.Distinct())
}
if stmt.OrderBy != nil {
if stmt.OrderByDirection == scanner.DESC {
s = s.Pipe(stream.SortReverse(stmt.OrderBy))
} else {
s = s.Pipe(stream.Sort(stmt.OrderBy))
}
}
if stmt.OffsetExpr != nil {
v, err := stmt.OffsetExpr.Eval(&expr.Environment{})
if err != nil {
return nil, err
}
if !v.Type.IsNumber() {
return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type)
}
v, err = v.CastAsInteger()
if err != nil {
return nil, err
}
s = s.Pipe(stream.Skip(v.V.(int64)))
}
if stmt.LimitExpr != nil {
v, err := stmt.LimitExpr.Eval(&expr.Environment{})
if err != nil {
return nil, err
}
if !v.Type.IsNumber() {
return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type)
}
v, err = v.CastAsInteger()
if err != nil {
return nil, err
}
s = s.Pipe(stream.Take(v.V.(int64)))
}
return &StreamStmt{
Stream: s,
ReadOnly: true,
}, nil
}

View File

@@ -57,6 +57,9 @@ func TestSelectStmt(t *testing.T) {
{"With group by and count", "SELECT COUNT(k) FROM test GROUP BY size", false, `[{"COUNT(k)":2},{"COUNT(k)":1}]`, nil},
{"With group by and count wildcard", "SELECT COUNT(* ) FROM test GROUP BY size", false, `[{"COUNT(*)":2},{"COUNT(*)":1}]`, nil},
{"With order by", "SELECT * FROM test ORDER BY color", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil},
{"With invalid group by / wildcard", "SELECT * FROM test WHERE age = 10 GROUP BY a.b.c", true, ``, nil},
{"With invalid group by / a.b", "SELECT a.b FROM test WHERE age = 10 GROUP BY a.b.c", true, ``, nil},
{"With order by", "SELECT * FROM test ORDER BY color", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil},
{"With order by asc", "SELECT * FROM test ORDER BY color ASC", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil},
{"With order by asc numeric", "SELECT * FROM test ORDER BY weight ASC", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100},{"k":3,"height":100,"weight":200}]`, nil},
{"With order by asc with limit 2", "SELECT * FROM test ORDER BY color LIMIT 2", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100}]`, nil},
@@ -94,6 +97,11 @@ func TestSelectStmt(t *testing.T) {
{"With two non existing idents, !=", "SELECT * FROM test WHERE z != y", false, `[]`, nil},
// See issue https://github.com/genjidb/genji/issues/283
{"With empty WHERE and IN", "SELECT * FROM test WHERE [] IN [];", false, `[]`, nil},
{"Invalid use of MIN() aggregator", "SELECT * FROM test LIMIT min(0)", true, ``, nil},
{"Invalid use of COUNT() aggregator", "SELECT * FROM test OFFSET x(*)", true, ``, nil},
{"Invalid use of MAX() aggregator", "SELECT * FROM test LIMIT max(0)", true, ``, nil},
{"Invalid use of SUM() aggregator", "SELECT * FROM test LIMIT sum(0)", true, ``, nil},
{"Invalid use of AVG() aggregator", "SELECT * FROM test LIMIT avg(0)", true, ``, nil},
}
for _, test := range tests {

65
internal/query/update.go Normal file
View File

@@ -0,0 +1,65 @@
package query
import (
"github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/stream"
)
// UpdateConfig holds UPDATE configuration.
type UpdateStmt struct {
TableName string
// SetPairs is used along with the Set clause. It holds
// each path with its corresponding value that
// should be set in the document.
SetPairs []UpdateSetPair
// UnsetFields is used along with the Unset clause. It holds
// each path that should be unset from the document.
UnsetFields []string
WhereExpr expr.Expr
}
type UpdateSetPair struct {
Path document.Path
E expr.Expr
}
func (stmt *UpdateStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
s := stmt.ToStream()
return s.Run(tx, params)
}
func (stmt *UpdateStmt) IsReadOnly() bool {
return false
}
// ToTree turns the statement into a stream.
func (stmt *UpdateStmt) ToStream() *StreamStmt {
s := stream.New(stream.SeqScan(stmt.TableName))
if stmt.WhereExpr != nil {
s = s.Pipe(stream.Filter(stmt.WhereExpr))
}
if stmt.SetPairs != nil {
for _, pair := range stmt.SetPairs {
s = s.Pipe(stream.Set(pair.Path, pair.E))
}
} else if stmt.UnsetFields != nil {
for _, name := range stmt.UnsetFields {
s = s.Pipe(stream.Unset(name))
}
}
s = s.Pipe(stream.TableReplace(stmt.TableName))
return &StreamStmt{
Stream: s,
ReadOnly: false,
}
}

View File

@@ -1,17 +1,14 @@
package parser
import (
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query"
"github.com/genjidb/genji/internal/sql/scanner"
"github.com/genjidb/genji/internal/stream"
"github.com/genjidb/genji/internal/stringutil"
)
// parseDeleteStatement parses a delete string and returns a Statement AST object.
// This function assumes the DELETE token has already been consumed.
func (p *Parser) parseDeleteStatement() (*query.StreamStmt, error) {
var cfg deleteConfig
func (p *Parser) parseDeleteStatement() (*query.DeleteStmt, error) {
var stmt query.DeleteStmt
var err error
// Parse "FROM".
@@ -20,7 +17,7 @@ func (p *Parser) parseDeleteStatement() (*query.StreamStmt, error) {
}
// Parse table name
cfg.TableName, err = p.parseIdent()
stmt.TableName, err = p.parseIdent()
if err != nil {
pErr := err.(*ParseError)
pErr.Expected = []string{"table_name"}
@@ -28,97 +25,28 @@ func (p *Parser) parseDeleteStatement() (*query.StreamStmt, error) {
}
// Parse condition: "WHERE EXPR".
cfg.WhereExpr, err = p.parseCondition()
stmt.WhereExpr, err = p.parseCondition()
if err != nil {
return nil, err
}
// Parse order by: "ORDER BY path [ASC|DESC]?"
cfg.OrderBy, cfg.OrderByDirection, err = p.parseOrderBy()
stmt.OrderBy, stmt.OrderByDirection, err = p.parseOrderBy()
if err != nil {
return nil, err
}
// Parse limit: "LIMIT expr"
cfg.LimitExpr, err = p.parseLimit()
stmt.LimitExpr, err = p.parseLimit()
if err != nil {
return nil, err
}
// Parse offset: "OFFSET expr"
cfg.OffsetExpr, err = p.parseOffset()
stmt.OffsetExpr, err = p.parseOffset()
if err != nil {
return nil, err
}
return cfg.ToStream()
}
// DeleteConfig holds DELETE configuration.
type deleteConfig struct {
TableName string
WhereExpr expr.Expr
OffsetExpr expr.Expr
OrderBy expr.Path
LimitExpr expr.Expr
OrderByDirection scanner.Token
}
func (cfg deleteConfig) ToStream() (*query.StreamStmt, error) {
s := stream.New(stream.SeqScan(cfg.TableName))
if cfg.WhereExpr != nil {
s = s.Pipe(stream.Filter(cfg.WhereExpr))
}
if cfg.OrderBy != nil {
if cfg.OrderByDirection == scanner.DESC {
s = s.Pipe(stream.SortReverse(cfg.OrderBy))
} else {
s = s.Pipe(stream.Sort(cfg.OrderBy))
}
}
if cfg.OffsetExpr != nil {
v, err := cfg.OffsetExpr.Eval(&expr.Environment{})
if err != nil {
return nil, err
}
if !v.Type.IsNumber() {
return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type)
}
v, err = v.CastAsInteger()
if err != nil {
return nil, err
}
s = s.Pipe(stream.Skip(v.V.(int64)))
}
if cfg.LimitExpr != nil {
v, err := cfg.LimitExpr.Eval(&expr.Environment{})
if err != nil {
return nil, err
}
if !v.Type.IsNumber() {
return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type)
}
v, err = v.CastAsInteger()
if err != nil {
return nil, err
}
s = s.Pipe(stream.Take(v.V.(int64)))
}
s = s.Pipe(stream.TableDelete(cfg.TableName))
return &query.StreamStmt{
Stream: s,
ReadOnly: false,
}, nil
return &stmt, nil
}

View File

@@ -54,7 +54,9 @@ func TestParserDelete(t *testing.T) {
q, err := parser.ParseQuery(test.s)
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, &query.StreamStmt{Stream: test.expected}, q.Statements[0])
stmt, err := q.Statements[0].(*query.DeleteStmt).ToStream()
require.NoError(t, err)
require.EqualValues(t, &query.StreamStmt{Stream: test.expected}, stmt)
})
}
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query"
"github.com/genjidb/genji/internal/sql/parser"
"github.com/genjidb/genji/internal/stream"
"github.com/stretchr/testify/require"
)
@@ -17,7 +16,10 @@ func TestParserExplain(t *testing.T) {
expected query.Statement
errored bool
}{
{"Explain create table", "EXPLAIN SELECT * FROM test", &query.ExplainStmt{Statement: &query.StreamStmt{Stream: stream.New(stream.SeqScan("test")).Pipe(stream.Project(expr.Wildcard{})), ReadOnly: true}}, false},
{"Explain create table", "EXPLAIN SELECT * FROM test", &query.ExplainStmt{Statement: &query.SelectStmt{
TableName: "test",
ProjectionExprs: []expr.Expr{expr.Wildcard{}},
}}, false},
{"Multiple Explains", "EXPLAIN EXPLAIN CREATE TABLE test", nil, true},
}

View File

@@ -1,19 +1,16 @@
package parser
import (
"errors"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query"
"github.com/genjidb/genji/internal/sql/scanner"
"github.com/genjidb/genji/internal/stream"
"github.com/genjidb/genji/internal/stringutil"
)
// parseInsertStatement parses an insert string and returns a Statement AST object.
// This function assumes the INSERT token has already been consumed.
func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
var cfg insertConfig
func (p *Parser) parseInsertStatement() (*query.InsertStmt, error) {
var stmt query.InsertStmt
var err error
// Parse "INTO".
@@ -22,7 +19,7 @@ func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
}
// Parse table name
cfg.TableName, err = p.parseIdent()
stmt.TableName, err = p.parseIdent()
if err != nil {
pErr := err.(*ParseError)
pErr.Expected = []string{"table_name"}
@@ -30,7 +27,7 @@ func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
}
// Parse path list: (a, b, c)
cfg.Fields, err = p.parseFieldList()
stmt.Fields, err = p.parseFieldList()
if err != nil {
return nil, err
}
@@ -40,12 +37,12 @@ func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
switch tok {
case scanner.VALUES:
// Parse VALUES (v1, v2, v3)
cfg.Values, err = p.parseValues(cfg.Fields)
stmt.Values, err = p.parseValues(stmt.Fields)
if err != nil {
return nil, err
}
case scanner.SELECT:
cfg.SelectStmt, err = p.parseSelectStatement()
stmt.SelectStmt, err = p.parseSelectStatement()
if err != nil {
return nil, err
}
@@ -53,12 +50,12 @@ func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
return nil, newParseError(scanner.Tokstr(tok, lit), []string{"VALUES", "SELECT"}, pos)
}
cfg.Returning, err = p.parseReturning()
stmt.Returning, err = p.parseReturning()
if err != nil {
return nil, err
}
return cfg.ToStream()
return &stmt, nil
}
// parseFieldList parses a list of fields in the form: (path, path, ...), if exists.
@@ -202,43 +199,3 @@ func (p *Parser) parseReturning() ([]expr.Expr, error) {
return p.parseProjectedExprs()
}
// insertConfig holds INSERT configuration.
type insertConfig struct {
TableName string
Values []expr.Expr
Fields []string
SelectStmt *query.StreamStmt
Returning []expr.Expr
}
func (cfg *insertConfig) ToStream() (*query.StreamStmt, error) {
var s *stream.Stream
if cfg.Values != nil {
s = stream.New(stream.Expressions(cfg.Values...))
s = s.Pipe(stream.TableInsert(cfg.TableName))
} else {
s = cfg.SelectStmt.Stream
// ensure we are not reading and writing to the same table.
if s.First().(*stream.SeqScanOperator).TableName == cfg.TableName {
return nil, errors.New("cannot read and write to the same table")
}
if len(cfg.Fields) > 0 {
s = s.Pipe(stream.IterRename(cfg.Fields...))
}
s = s.Pipe(stream.TableInsert(cfg.TableName))
}
if len(cfg.Returning) > 0 {
s = s.Pipe(stream.Project(cfg.Returning...))
}
return &query.StreamStmt{
Stream: s,
ReadOnly: false,
}, nil
}

View File

@@ -87,8 +87,6 @@ func TestParserInsert(t *testing.T) {
nil, true},
{"Values / Without fields / Wrong values", "INSERT INTO test VALUES {a: 1}, ('e', 'f')",
nil, true},
{"Select / same table", "INSERT INTO test SELECT * FROM test",
nil, true},
{"Select / Without fields", "INSERT INTO test SELECT * FROM foo",
stream.New(stream.SeqScan("foo")).
Pipe(stream.Project(expr.Wildcard{})).
@@ -135,9 +133,11 @@ func TestParserInsert(t *testing.T) {
}
require.NoError(t, err)
require.Len(t, q.Statements, 1)
stmt := q.Statements[0].(*query.StreamStmt)
require.False(t, stmt.ReadOnly)
require.EqualValues(t, test.expected.String(), stmt.Stream.String())
stmt := q.Statements[0].(*query.InsertStmt)
require.False(t, stmt.IsReadOnly())
ss, err := stmt.ToStream()
require.NoError(t, err)
require.EqualValues(t, test.expected.String(), ss.String())
})
}
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query"
"github.com/genjidb/genji/internal/sql/parser"
"github.com/genjidb/genji/internal/stream"
)
func TestParserMultiStatement(t *testing.T) {
@@ -19,12 +18,12 @@ func TestParserMultiStatement(t *testing.T) {
}{
{"OnlyCommas", ";;;", nil},
{"TrailingComma", "SELECT * FROM foo;;;DELETE FROM foo;", []query.Statement{
&query.StreamStmt{
Stream: stream.New(stream.SeqScan("foo")).Pipe(stream.Project(expr.Wildcard{})),
ReadOnly: true,
&query.SelectStmt{
TableName: "foo",
ProjectionExprs: []expr.Expr{expr.Wildcard{}},
},
&query.StreamStmt{
Stream: stream.New(stream.SeqScan("foo")).Pipe(stream.TableDelete("foo")),
&query.DeleteStmt{
TableName: "foo",
},
}},
}

View File

@@ -1,73 +1,69 @@
package parser
import (
"errors"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query"
"github.com/genjidb/genji/internal/sql/scanner"
"github.com/genjidb/genji/internal/stream"
"github.com/genjidb/genji/internal/stringutil"
)
// parseSelectStatement parses a select string and returns a Statement AST object.
// This function assumes the SELECT token has already been consumed.
func (p *Parser) parseSelectStatement() (*query.StreamStmt, error) {
var cfg selectConfig
func (p *Parser) parseSelectStatement() (*query.SelectStmt, error) {
var stmt query.SelectStmt
var err error
cfg.Distinct, err = p.parseDistinct()
stmt.Distinct, err = p.parseDistinct()
if err != nil {
return nil, err
}
// Parse path list or query.Wildcard
cfg.ProjectionExprs, err = p.parseProjectedExprs()
stmt.ProjectionExprs, err = p.parseProjectedExprs()
if err != nil {
return nil, err
}
// Parse "FROM".
var found bool
cfg.TableName, found, err = p.parseFrom()
stmt.TableName, found, err = p.parseFrom()
if err != nil {
return nil, err
}
if !found {
return cfg.ToStream()
return &stmt, nil
}
// Parse condition: "WHERE expr".
cfg.WhereExpr, err = p.parseCondition()
stmt.WhereExpr, err = p.parseCondition()
if err != nil {
return nil, err
}
// Parse group by: "GROUP BY expr"
cfg.GroupByExpr, err = p.parseGroupBy()
stmt.GroupByExpr, err = p.parseGroupBy()
if err != nil {
return nil, err
}
// Parse order by: "ORDER BY path [ASC|DESC]?"
cfg.OrderBy, cfg.OrderByDirection, err = p.parseOrderBy()
stmt.OrderBy, stmt.OrderByDirection, err = p.parseOrderBy()
if err != nil {
return nil, err
}
// Parse limit: "LIMIT expr"
cfg.LimitExpr, err = p.parseLimit()
stmt.LimitExpr, err = p.parseLimit()
if err != nil {
return nil, err
}
// Parse offset: "OFFSET expr"
cfg.OffsetExpr, err = p.parseOffset()
stmt.OffsetExpr, err = p.parseOffset()
if err != nil {
return nil, err
}
return cfg.ToStream()
return &stmt, nil
}
// parseProjectedExprs parses the list of projected fields.
@@ -158,165 +154,3 @@ func (p *Parser) parseGroupBy() (expr.Expr, error) {
e, err := p.ParseExpr()
return e, err
}
// SelectConfig holds SELECT configuration.
type selectConfig struct {
TableName string
Distinct bool
WhereExpr expr.Expr
GroupByExpr expr.Expr
OrderBy expr.Path
OrderByDirection scanner.Token
OffsetExpr expr.Expr
LimitExpr expr.Expr
ProjectionExprs []expr.Expr
}
func (cfg selectConfig) ToStream() (*query.StreamStmt, error) {
var s *stream.Stream
if cfg.TableName != "" {
s = stream.New(stream.SeqScan(cfg.TableName))
}
if cfg.WhereExpr != nil {
s = s.Pipe(stream.Filter(cfg.WhereExpr))
}
// when using GROUP BY, only aggregation functions or GroupByExpr can be selected
if cfg.GroupByExpr != nil {
// add Group node
s = s.Pipe(stream.GroupBy(cfg.GroupByExpr))
var invalidProjectedField expr.Expr
var aggregators []expr.AggregatorBuilder
for _, pe := range cfg.ProjectionExprs {
ne, ok := pe.(*expr.NamedExpr)
if !ok {
invalidProjectedField = pe
break
}
e := ne.Expr
// check if the projected expression is an aggregation function
if agg, ok := e.(expr.AggregatorBuilder); ok {
aggregators = append(aggregators, agg)
continue
}
// check if this is the same expression as the one used in the GROUP BY clause
if expr.Equal(e, cfg.GroupByExpr) {
continue
}
// otherwise it's an error
invalidProjectedField = ne
break
}
if invalidProjectedField != nil {
return nil, stringutil.Errorf("field %q must appear in the GROUP BY clause or be used in an aggregate function", invalidProjectedField)
}
// add Aggregation node
s = s.Pipe(stream.HashAggregate(aggregators...))
} else {
// if there is no GROUP BY clause, check if there are any aggregation function
// and if so add an aggregation node
var aggregators []expr.AggregatorBuilder
for _, pe := range cfg.ProjectionExprs {
ne, ok := pe.(*expr.NamedExpr)
if !ok {
continue
}
e := ne.Expr
// check if the projected expression is an aggregation function
if agg, ok := e.(expr.AggregatorBuilder); ok {
aggregators = append(aggregators, agg)
}
}
// add Aggregation node
if len(aggregators) > 0 {
s = s.Pipe(stream.HashAggregate(aggregators...))
}
}
// If there is no FROM clause ensure there is no wildcard or path
if cfg.TableName == "" {
var err error
for _, e := range cfg.ProjectionExprs {
expr.Walk(e, func(e expr.Expr) bool {
switch e.(type) {
case expr.Path, expr.Wildcard:
err = errors.New("no tables specified")
return false
default:
return true
}
})
if err != nil {
return nil, err
}
}
}
s = s.Pipe(stream.Project(cfg.ProjectionExprs...))
if cfg.Distinct {
s = s.Pipe(stream.Distinct())
}
if cfg.OrderBy != nil {
if cfg.OrderByDirection == scanner.DESC {
s = s.Pipe(stream.SortReverse(cfg.OrderBy))
} else {
s = s.Pipe(stream.Sort(cfg.OrderBy))
}
}
if cfg.OffsetExpr != nil {
v, err := cfg.OffsetExpr.Eval(&expr.Environment{})
if err != nil {
return nil, err
}
if !v.Type.IsNumber() {
return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type)
}
v, err = v.CastAsInteger()
if err != nil {
return nil, err
}
s = s.Pipe(stream.Skip(v.V.(int64)))
}
if cfg.LimitExpr != nil {
v, err := cfg.LimitExpr.Eval(&expr.Environment{})
if err != nil {
return nil, err
}
if !v.Type.IsNumber() {
return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type)
}
v, err = v.CastAsInteger()
if err != nil {
return nil, err
}
s = s.Pipe(stream.Take(v.V.(int64)))
}
return &query.StreamStmt{
Stream: s,
ReadOnly: true,
}, nil
}

View File

@@ -22,15 +22,6 @@ func TestParserSelect(t *testing.T) {
stream.New(stream.Project(testutil.ParseNamedExpr(t, "1"))),
false,
},
{"NoTable/path", "SELECT a",
nil,
true,
},
{"NoTable/wildcard", "SELECT *",
nil,
true,
},
{"Wildcard with no FORM", "SELECT *", nil, true},
{"NoTableWithTuple", "SELECT (1, 2)",
stream.New(stream.Project(testutil.ParseNamedExpr(t, "[1, 2]"))),
false,
@@ -84,8 +75,6 @@ func TestParserSelect(t *testing.T) {
Pipe(stream.Project(testutil.ParseNamedExpr(t, "a.b.c"))),
false,
},
{"With Invalid GroupBy: Wildcard", "SELECT * FROM test WHERE age = 10 GROUP BY a.b.c", nil, true},
{"With Invalid GroupBy: a.b", "SELECT a.b FROM test WHERE age = 10 GROUP BY a.b.c", nil, true},
{"WithOrderBy", "SELECT * FROM test WHERE age = 10 ORDER BY a.b.c",
stream.New(stream.SeqScan("test")).
Pipe(stream.Filter(parser.MustParseExpr("age = 10"))).
@@ -135,11 +124,6 @@ func TestParserSelect(t *testing.T) {
Pipe(stream.HashAggregate(&expr.CountFunc{Wildcard: true})).
Pipe(stream.Project(testutil.ParseNamedExpr(t, "COUNT(*)"))),
false},
{"Invalid use of MIN() aggregator", "SELECT * FROM test LIMIT min(0)", nil, true},
{"Invalid use of COUNT() aggregator", "SELECT * FROM test OFFSET x(*)", nil, true},
{"Invalid use of MAX() aggregator", "SELECT * FROM test LIMIT max(0)", nil, true},
{"Invalid use of SUM() aggregator", "SELECT * FROM test LIMIT sum(0)", nil, true},
{"Invalid use of AVG() aggregator", "SELECT * FROM test LIMIT avg(0)", nil, true},
}
for _, test := range tests {
@@ -148,7 +132,9 @@ func TestParserSelect(t *testing.T) {
if !test.mustFail {
require.NoError(t, err)
require.Len(t, q.Statements, 1)
require.EqualValues(t, &query.StreamStmt{Stream: test.expected, ReadOnly: true}, q.Statements[0])
st, err := q.Statements[0].(*query.SelectStmt).ToStream()
require.NoError(t, err)
require.EqualValues(t, &query.StreamStmt{Stream: test.expected, ReadOnly: true}, st)
} else {
require.Error(t, err)
}

View File

@@ -1,21 +1,18 @@
package parser
import (
"github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query"
"github.com/genjidb/genji/internal/sql/scanner"
"github.com/genjidb/genji/internal/stream"
)
// parseUpdateStatement parses a update string and returns a Statement AST object.
// This function assumes the UPDATE token has already been consumed.
func (p *Parser) parseUpdateStatement() (*query.StreamStmt, error) {
var cfg updateConfig
func (p *Parser) parseUpdateStatement() (*query.UpdateStmt, error) {
var stmt query.UpdateStmt
var err error
// Parse table name
cfg.TableName, err = p.parseIdent()
stmt.TableName, err = p.parseIdent()
if err != nil {
pErr := err.(*ParseError)
pErr.Expected = []string{"table_name"}
@@ -26,9 +23,9 @@ func (p *Parser) parseUpdateStatement() (*query.StreamStmt, error) {
tok, pos, lit := p.ScanIgnoreWhitespace()
switch tok {
case scanner.SET:
cfg.SetPairs, err = p.parseSetClause()
stmt.SetPairs, err = p.parseSetClause()
case scanner.UNSET:
cfg.UnsetFields, err = p.parseUnsetClause()
stmt.UnsetFields, err = p.parseUnsetClause()
default:
err = newParseError(scanner.Tokstr(tok, lit), []string{"SET", "UNSET"}, pos)
}
@@ -37,17 +34,17 @@ func (p *Parser) parseUpdateStatement() (*query.StreamStmt, error) {
}
// Parse condition: "WHERE EXPR".
cfg.WhereExpr, err = p.parseCondition()
stmt.WhereExpr, err = p.parseCondition()
if err != nil {
return nil, err
}
return cfg.ToStream(), nil
return &stmt, nil
}
// parseSetClause parses the "SET" clause of the query.
func (p *Parser) parseSetClause() ([]updateSetPair, error) {
var pairs []updateSetPair
func (p *Parser) parseSetClause() ([]query.UpdateSetPair, error) {
var pairs []query.UpdateSetPair
firstPair := true
for {
@@ -78,7 +75,7 @@ func (p *Parser) parseSetClause() ([]updateSetPair, error) {
if err != nil {
return nil, err
}
pairs = append(pairs, updateSetPair{path, expr})
pairs = append(pairs, query.UpdateSetPair{Path: path, E: expr})
firstPair = false
}
@@ -111,50 +108,3 @@ func (p *Parser) parseUnsetClause() ([]string, error) {
}
return fields, nil
}
// UpdateConfig holds UPDATE configuration.
type updateConfig struct {
TableName string
// SetPairs is used along with the Set clause. It holds
// each path with its corresponding value that
// should be set in the document.
SetPairs []updateSetPair
// UnsetFields is used along with the Unset clause. It holds
// each path that should be unset from the document.
UnsetFields []string
WhereExpr expr.Expr
}
type updateSetPair struct {
path document.Path
e expr.Expr
}
// ToTree turns the statement into a stream.
func (cfg updateConfig) ToStream() *query.StreamStmt {
s := stream.New(stream.SeqScan(cfg.TableName))
if cfg.WhereExpr != nil {
s = s.Pipe(stream.Filter(cfg.WhereExpr))
}
if cfg.SetPairs != nil {
for _, pair := range cfg.SetPairs {
s = s.Pipe(stream.Set(pair.path, pair.e))
}
} else if cfg.UnsetFields != nil {
for _, name := range cfg.UnsetFields {
s = s.Pipe(stream.Unset(name))
}
}
s = s.Pipe(stream.TableReplace(cfg.TableName))
return &query.StreamStmt{
Stream: s,
ReadOnly: false,
}
}

View File

@@ -75,9 +75,9 @@ func TestParserUpdate(t *testing.T) {
require.NoError(t, err)
require.Len(t, q.Statements, 1)
stmt := q.Statements[0].(*query.StreamStmt)
require.False(t, stmt.ReadOnly)
require.EqualValues(t, test.expected, stmt.Stream)
stmt := q.Statements[0].(*query.UpdateStmt)
require.False(t, stmt.IsReadOnly())
require.EqualValues(t, test.expected, stmt.ToStream().Stream)
})
}
}