mirror of
https://github.com/chaisql/chai.git
synced 2025-10-29 18:11:50 +08:00
Move stream building to query package
This commit is contained in:
93
internal/query/delete.go
Normal file
93
internal/query/delete.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"github.com/genjidb/genji/internal/database"
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/sql/scanner"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
"github.com/genjidb/genji/internal/stringutil"
|
||||
)
|
||||
|
||||
// DeleteConfig holds DELETE configuration.
|
||||
type DeleteStmt struct {
|
||||
TableName string
|
||||
WhereExpr expr.Expr
|
||||
OffsetExpr expr.Expr
|
||||
OrderBy expr.Path
|
||||
LimitExpr expr.Expr
|
||||
OrderByDirection scanner.Token
|
||||
}
|
||||
|
||||
func (stmt *DeleteStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
|
||||
var res Result
|
||||
|
||||
s, err := stmt.ToStream()
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return s.Run(tx, params)
|
||||
}
|
||||
|
||||
func (stmt *DeleteStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt *DeleteStmt) ToStream() (*StreamStmt, error) {
|
||||
s := stream.New(stream.SeqScan(stmt.TableName))
|
||||
|
||||
if stmt.WhereExpr != nil {
|
||||
s = s.Pipe(stream.Filter(stmt.WhereExpr))
|
||||
}
|
||||
|
||||
if stmt.OrderBy != nil {
|
||||
if stmt.OrderByDirection == scanner.DESC {
|
||||
s = s.Pipe(stream.SortReverse(stmt.OrderBy))
|
||||
} else {
|
||||
s = s.Pipe(stream.Sort(stmt.OrderBy))
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.OffsetExpr != nil {
|
||||
v, err := stmt.OffsetExpr.Eval(&expr.Environment{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !v.Type.IsNumber() {
|
||||
return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type)
|
||||
}
|
||||
|
||||
v, err = v.CastAsInteger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Skip(v.V.(int64)))
|
||||
}
|
||||
|
||||
if stmt.LimitExpr != nil {
|
||||
v, err := stmt.LimitExpr.Eval(&expr.Environment{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !v.Type.IsNumber() {
|
||||
return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type)
|
||||
}
|
||||
|
||||
v, err = v.CastAsInteger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Take(v.V.(int64)))
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.TableDelete(stmt.TableName))
|
||||
|
||||
return &StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}, nil
|
||||
}
|
||||
@@ -21,35 +21,50 @@ type ExplainStmt struct {
|
||||
// If the statement is a stream, Optimize will be called prior to
|
||||
// displaying all the operations.
|
||||
// Explain currently only works on SELECT, UPDATE, INSERT and DELETE statements.
|
||||
func (s *ExplainStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
|
||||
switch t := s.Statement.(type) {
|
||||
case *StreamStmt:
|
||||
s, err := planner.Optimize(t.Stream, tx, params)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
func (stmt *ExplainStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
|
||||
var ss *StreamStmt
|
||||
var err error
|
||||
var res Result
|
||||
|
||||
var plan string
|
||||
if s != nil {
|
||||
plan = s.String()
|
||||
} else {
|
||||
plan = "<no exec>"
|
||||
}
|
||||
|
||||
newStatement := StreamStmt{
|
||||
Stream: &stream.Stream{
|
||||
Op: stream.Project(
|
||||
&expr.NamedExpr{
|
||||
ExprName: "plan",
|
||||
Expr: expr.LiteralValue(document.NewTextValue(plan)),
|
||||
}),
|
||||
},
|
||||
ReadOnly: true,
|
||||
}
|
||||
return newStatement.Run(tx, params)
|
||||
switch t := stmt.Statement.(type) {
|
||||
case *SelectStmt:
|
||||
ss, err = t.ToStream()
|
||||
case *UpdateStmt:
|
||||
ss = t.ToStream()
|
||||
case *InsertStmt:
|
||||
ss, err = t.ToStream()
|
||||
case *DeleteStmt:
|
||||
ss, err = t.ToStream()
|
||||
default:
|
||||
return Result{}, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements")
|
||||
}
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return Result{}, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements")
|
||||
s, err := planner.Optimize(ss.Stream, tx, params)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
var plan string
|
||||
if s != nil {
|
||||
plan = s.String()
|
||||
} else {
|
||||
plan = "<no exec>"
|
||||
}
|
||||
|
||||
newStatement := StreamStmt{
|
||||
Stream: &stream.Stream{
|
||||
Op: stream.Project(
|
||||
&expr.NamedExpr{
|
||||
ExprName: "plan",
|
||||
Expr: expr.LiteralValue(document.NewTextValue(plan)),
|
||||
}),
|
||||
},
|
||||
ReadOnly: true,
|
||||
}
|
||||
return newStatement.Run(tx, params)
|
||||
}
|
||||
|
||||
// IsReadOnly indicates that this statement doesn't write anything into
|
||||
|
||||
69
internal/query/insert.go
Normal file
69
internal/query/insert.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/genjidb/genji/internal/database"
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
)
|
||||
|
||||
// InsertStmt holds INSERT configuration.
|
||||
type InsertStmt struct {
|
||||
TableName string
|
||||
Values []expr.Expr
|
||||
Fields []string
|
||||
SelectStmt *SelectStmt
|
||||
Returning []expr.Expr
|
||||
}
|
||||
|
||||
func (stmt *InsertStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
|
||||
var res Result
|
||||
|
||||
s, err := stmt.ToStream()
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return s.Run(tx, params)
|
||||
}
|
||||
|
||||
func (stmt *InsertStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (stmt *InsertStmt) ToStream() (*StreamStmt, error) {
|
||||
var s *stream.Stream
|
||||
|
||||
if stmt.Values != nil {
|
||||
s = stream.New(stream.Expressions(stmt.Values...))
|
||||
|
||||
s = s.Pipe(stream.TableInsert(stmt.TableName))
|
||||
} else {
|
||||
st, err := stmt.SelectStmt.ToStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s = st.Stream
|
||||
|
||||
// ensure we are not reading and writing to the same table.
|
||||
if s.First().(*stream.SeqScanOperator).TableName == stmt.TableName {
|
||||
return nil, errors.New("cannot read and write to the same table")
|
||||
}
|
||||
|
||||
if len(stmt.Fields) > 0 {
|
||||
s = s.Pipe(stream.IterRename(stmt.Fields...))
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.TableInsert(stmt.TableName))
|
||||
}
|
||||
|
||||
if len(stmt.Returning) > 0 {
|
||||
s = s.Pipe(stream.Project(stmt.Returning...))
|
||||
}
|
||||
|
||||
return &StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}, nil
|
||||
}
|
||||
@@ -23,6 +23,7 @@ func TestInsertStmt(t *testing.T) {
|
||||
{"Values / Invalid params", "INSERT INTO test (a, b, c) VALUES ('d', ?)", true, "", []interface{}{'e'}},
|
||||
{"Documents / Named Params", "INSERT INTO test VALUES {a: $a, b: 2.3, c: $c}", false, `[{"pk()":1,"a":1,"b":2.3,"c":true}]`, []interface{}{sql.Named("c", true), sql.Named("a", 1)}},
|
||||
{"Documents / List ", "INSERT INTO test VALUES {a: [1, 2, 3]}", false, `[{"pk()":1,"a":[1,2,3]}]`, nil},
|
||||
{"Select / same table", "INSERT INTO test SELECT * FROM test", true, ``, nil},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
188
internal/query/select.go
Normal file
188
internal/query/select.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/genjidb/genji/internal/database"
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/sql/scanner"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
"github.com/genjidb/genji/internal/stringutil"
|
||||
)
|
||||
|
||||
// SelectStmt holds SELECT configuration.
|
||||
type SelectStmt struct {
|
||||
TableName string
|
||||
Distinct bool
|
||||
WhereExpr expr.Expr
|
||||
GroupByExpr expr.Expr
|
||||
OrderBy expr.Path
|
||||
OrderByDirection scanner.Token
|
||||
OffsetExpr expr.Expr
|
||||
LimitExpr expr.Expr
|
||||
ProjectionExprs []expr.Expr
|
||||
}
|
||||
|
||||
func (stmt *SelectStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
|
||||
var res Result
|
||||
|
||||
s, err := stmt.ToStream()
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return s.Run(tx, params)
|
||||
}
|
||||
|
||||
func (stmt *SelectStmt) IsReadOnly() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (stmt *SelectStmt) ToStream() (*StreamStmt, error) {
|
||||
var s *stream.Stream
|
||||
|
||||
if stmt.TableName != "" {
|
||||
s = stream.New(stream.SeqScan(stmt.TableName))
|
||||
}
|
||||
|
||||
if stmt.WhereExpr != nil {
|
||||
s = s.Pipe(stream.Filter(stmt.WhereExpr))
|
||||
}
|
||||
|
||||
// when using GROUP BY, only aggregation functions or GroupByExpr can be selected
|
||||
if stmt.GroupByExpr != nil {
|
||||
// add Group node
|
||||
s = s.Pipe(stream.GroupBy(stmt.GroupByExpr))
|
||||
|
||||
var invalidProjectedField expr.Expr
|
||||
var aggregators []expr.AggregatorBuilder
|
||||
|
||||
for _, pe := range stmt.ProjectionExprs {
|
||||
ne, ok := pe.(*expr.NamedExpr)
|
||||
if !ok {
|
||||
invalidProjectedField = pe
|
||||
break
|
||||
}
|
||||
e := ne.Expr
|
||||
|
||||
// check if the projected expression is an aggregation function
|
||||
if agg, ok := e.(expr.AggregatorBuilder); ok {
|
||||
aggregators = append(aggregators, agg)
|
||||
continue
|
||||
}
|
||||
|
||||
// check if this is the same expression as the one used in the GROUP BY clause
|
||||
if expr.Equal(e, stmt.GroupByExpr) {
|
||||
continue
|
||||
}
|
||||
|
||||
// otherwise it's an error
|
||||
invalidProjectedField = ne
|
||||
break
|
||||
}
|
||||
|
||||
if invalidProjectedField != nil {
|
||||
return nil, stringutil.Errorf("field %q must appear in the GROUP BY clause or be used in an aggregate function", invalidProjectedField)
|
||||
}
|
||||
|
||||
// add Aggregation node
|
||||
s = s.Pipe(stream.HashAggregate(aggregators...))
|
||||
} else {
|
||||
// if there is no GROUP BY clause, check if there are any aggregation function
|
||||
// and if so add an aggregation node
|
||||
var aggregators []expr.AggregatorBuilder
|
||||
|
||||
for _, pe := range stmt.ProjectionExprs {
|
||||
ne, ok := pe.(*expr.NamedExpr)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
e := ne.Expr
|
||||
|
||||
// check if the projected expression is an aggregation function
|
||||
if agg, ok := e.(expr.AggregatorBuilder); ok {
|
||||
aggregators = append(aggregators, agg)
|
||||
}
|
||||
}
|
||||
|
||||
// add Aggregation node
|
||||
if len(aggregators) > 0 {
|
||||
s = s.Pipe(stream.HashAggregate(aggregators...))
|
||||
}
|
||||
}
|
||||
|
||||
// If there is no FROM clause ensure there is no wildcard or path
|
||||
if stmt.TableName == "" {
|
||||
var err error
|
||||
|
||||
for _, e := range stmt.ProjectionExprs {
|
||||
expr.Walk(e, func(e expr.Expr) bool {
|
||||
switch e.(type) {
|
||||
case expr.Path, expr.Wildcard:
|
||||
err = errors.New("no tables specified")
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Project(stmt.ProjectionExprs...))
|
||||
|
||||
if stmt.Distinct {
|
||||
s = s.Pipe(stream.Distinct())
|
||||
}
|
||||
|
||||
if stmt.OrderBy != nil {
|
||||
if stmt.OrderByDirection == scanner.DESC {
|
||||
s = s.Pipe(stream.SortReverse(stmt.OrderBy))
|
||||
} else {
|
||||
s = s.Pipe(stream.Sort(stmt.OrderBy))
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.OffsetExpr != nil {
|
||||
v, err := stmt.OffsetExpr.Eval(&expr.Environment{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !v.Type.IsNumber() {
|
||||
return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type)
|
||||
}
|
||||
|
||||
v, err = v.CastAsInteger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Skip(v.V.(int64)))
|
||||
}
|
||||
|
||||
if stmt.LimitExpr != nil {
|
||||
v, err := stmt.LimitExpr.Eval(&expr.Environment{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !v.Type.IsNumber() {
|
||||
return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type)
|
||||
}
|
||||
|
||||
v, err = v.CastAsInteger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Take(v.V.(int64)))
|
||||
}
|
||||
|
||||
return &StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: true,
|
||||
}, nil
|
||||
}
|
||||
@@ -57,6 +57,9 @@ func TestSelectStmt(t *testing.T) {
|
||||
{"With group by and count", "SELECT COUNT(k) FROM test GROUP BY size", false, `[{"COUNT(k)":2},{"COUNT(k)":1}]`, nil},
|
||||
{"With group by and count wildcard", "SELECT COUNT(* ) FROM test GROUP BY size", false, `[{"COUNT(*)":2},{"COUNT(*)":1}]`, nil},
|
||||
{"With order by", "SELECT * FROM test ORDER BY color", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil},
|
||||
{"With invalid group by / wildcard", "SELECT * FROM test WHERE age = 10 GROUP BY a.b.c", true, ``, nil},
|
||||
{"With invalid group by / a.b", "SELECT a.b FROM test WHERE age = 10 GROUP BY a.b.c", true, ``, nil},
|
||||
{"With order by", "SELECT * FROM test ORDER BY color", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil},
|
||||
{"With order by asc", "SELECT * FROM test ORDER BY color ASC", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil},
|
||||
{"With order by asc numeric", "SELECT * FROM test ORDER BY weight ASC", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100},{"k":3,"height":100,"weight":200}]`, nil},
|
||||
{"With order by asc with limit 2", "SELECT * FROM test ORDER BY color LIMIT 2", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100}]`, nil},
|
||||
@@ -94,6 +97,11 @@ func TestSelectStmt(t *testing.T) {
|
||||
{"With two non existing idents, !=", "SELECT * FROM test WHERE z != y", false, `[]`, nil},
|
||||
// See issue https://github.com/genjidb/genji/issues/283
|
||||
{"With empty WHERE and IN", "SELECT * FROM test WHERE [] IN [];", false, `[]`, nil},
|
||||
{"Invalid use of MIN() aggregator", "SELECT * FROM test LIMIT min(0)", true, ``, nil},
|
||||
{"Invalid use of COUNT() aggregator", "SELECT * FROM test OFFSET x(*)", true, ``, nil},
|
||||
{"Invalid use of MAX() aggregator", "SELECT * FROM test LIMIT max(0)", true, ``, nil},
|
||||
{"Invalid use of SUM() aggregator", "SELECT * FROM test LIMIT sum(0)", true, ``, nil},
|
||||
{"Invalid use of AVG() aggregator", "SELECT * FROM test LIMIT avg(0)", true, ``, nil},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
65
internal/query/update.go
Normal file
65
internal/query/update.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"github.com/genjidb/genji/document"
|
||||
"github.com/genjidb/genji/internal/database"
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
)
|
||||
|
||||
// UpdateConfig holds UPDATE configuration.
|
||||
type UpdateStmt struct {
|
||||
TableName string
|
||||
|
||||
// SetPairs is used along with the Set clause. It holds
|
||||
// each path with its corresponding value that
|
||||
// should be set in the document.
|
||||
SetPairs []UpdateSetPair
|
||||
|
||||
// UnsetFields is used along with the Unset clause. It holds
|
||||
// each path that should be unset from the document.
|
||||
UnsetFields []string
|
||||
|
||||
WhereExpr expr.Expr
|
||||
}
|
||||
|
||||
type UpdateSetPair struct {
|
||||
Path document.Path
|
||||
E expr.Expr
|
||||
}
|
||||
|
||||
func (stmt *UpdateStmt) Run(tx *database.Transaction, params []expr.Param) (Result, error) {
|
||||
s := stmt.ToStream()
|
||||
|
||||
return s.Run(tx, params)
|
||||
}
|
||||
|
||||
func (stmt *UpdateStmt) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ToTree turns the statement into a stream.
|
||||
func (stmt *UpdateStmt) ToStream() *StreamStmt {
|
||||
s := stream.New(stream.SeqScan(stmt.TableName))
|
||||
|
||||
if stmt.WhereExpr != nil {
|
||||
s = s.Pipe(stream.Filter(stmt.WhereExpr))
|
||||
}
|
||||
|
||||
if stmt.SetPairs != nil {
|
||||
for _, pair := range stmt.SetPairs {
|
||||
s = s.Pipe(stream.Set(pair.Path, pair.E))
|
||||
}
|
||||
} else if stmt.UnsetFields != nil {
|
||||
for _, name := range stmt.UnsetFields {
|
||||
s = s.Pipe(stream.Unset(name))
|
||||
}
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.TableReplace(stmt.TableName))
|
||||
|
||||
return &StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,14 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/query"
|
||||
"github.com/genjidb/genji/internal/sql/scanner"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
"github.com/genjidb/genji/internal/stringutil"
|
||||
)
|
||||
|
||||
// parseDeleteStatement parses a delete string and returns a Statement AST object.
|
||||
// This function assumes the DELETE token has already been consumed.
|
||||
func (p *Parser) parseDeleteStatement() (*query.StreamStmt, error) {
|
||||
var cfg deleteConfig
|
||||
func (p *Parser) parseDeleteStatement() (*query.DeleteStmt, error) {
|
||||
var stmt query.DeleteStmt
|
||||
var err error
|
||||
|
||||
// Parse "FROM".
|
||||
@@ -20,7 +17,7 @@ func (p *Parser) parseDeleteStatement() (*query.StreamStmt, error) {
|
||||
}
|
||||
|
||||
// Parse table name
|
||||
cfg.TableName, err = p.parseIdent()
|
||||
stmt.TableName, err = p.parseIdent()
|
||||
if err != nil {
|
||||
pErr := err.(*ParseError)
|
||||
pErr.Expected = []string{"table_name"}
|
||||
@@ -28,97 +25,28 @@ func (p *Parser) parseDeleteStatement() (*query.StreamStmt, error) {
|
||||
}
|
||||
|
||||
// Parse condition: "WHERE EXPR".
|
||||
cfg.WhereExpr, err = p.parseCondition()
|
||||
stmt.WhereExpr, err = p.parseCondition()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse order by: "ORDER BY path [ASC|DESC]?"
|
||||
cfg.OrderBy, cfg.OrderByDirection, err = p.parseOrderBy()
|
||||
stmt.OrderBy, stmt.OrderByDirection, err = p.parseOrderBy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse limit: "LIMIT expr"
|
||||
cfg.LimitExpr, err = p.parseLimit()
|
||||
stmt.LimitExpr, err = p.parseLimit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse offset: "OFFSET expr"
|
||||
cfg.OffsetExpr, err = p.parseOffset()
|
||||
stmt.OffsetExpr, err = p.parseOffset()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg.ToStream()
|
||||
}
|
||||
|
||||
// DeleteConfig holds DELETE configuration.
|
||||
type deleteConfig struct {
|
||||
TableName string
|
||||
WhereExpr expr.Expr
|
||||
OffsetExpr expr.Expr
|
||||
OrderBy expr.Path
|
||||
LimitExpr expr.Expr
|
||||
OrderByDirection scanner.Token
|
||||
}
|
||||
|
||||
func (cfg deleteConfig) ToStream() (*query.StreamStmt, error) {
|
||||
s := stream.New(stream.SeqScan(cfg.TableName))
|
||||
|
||||
if cfg.WhereExpr != nil {
|
||||
s = s.Pipe(stream.Filter(cfg.WhereExpr))
|
||||
}
|
||||
|
||||
if cfg.OrderBy != nil {
|
||||
if cfg.OrderByDirection == scanner.DESC {
|
||||
s = s.Pipe(stream.SortReverse(cfg.OrderBy))
|
||||
} else {
|
||||
s = s.Pipe(stream.Sort(cfg.OrderBy))
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.OffsetExpr != nil {
|
||||
v, err := cfg.OffsetExpr.Eval(&expr.Environment{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !v.Type.IsNumber() {
|
||||
return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type)
|
||||
}
|
||||
|
||||
v, err = v.CastAsInteger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Skip(v.V.(int64)))
|
||||
}
|
||||
|
||||
if cfg.LimitExpr != nil {
|
||||
v, err := cfg.LimitExpr.Eval(&expr.Environment{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !v.Type.IsNumber() {
|
||||
return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type)
|
||||
}
|
||||
|
||||
v, err = v.CastAsInteger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Take(v.V.(int64)))
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.TableDelete(cfg.TableName))
|
||||
|
||||
return &query.StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}, nil
|
||||
return &stmt, nil
|
||||
}
|
||||
|
||||
@@ -54,7 +54,9 @@ func TestParserDelete(t *testing.T) {
|
||||
q, err := parser.ParseQuery(test.s)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, &query.StreamStmt{Stream: test.expected}, q.Statements[0])
|
||||
stmt, err := q.Statements[0].(*query.DeleteStmt).ToStream()
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, &query.StreamStmt{Stream: test.expected}, stmt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/query"
|
||||
"github.com/genjidb/genji/internal/sql/parser"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -17,7 +16,10 @@ func TestParserExplain(t *testing.T) {
|
||||
expected query.Statement
|
||||
errored bool
|
||||
}{
|
||||
{"Explain create table", "EXPLAIN SELECT * FROM test", &query.ExplainStmt{Statement: &query.StreamStmt{Stream: stream.New(stream.SeqScan("test")).Pipe(stream.Project(expr.Wildcard{})), ReadOnly: true}}, false},
|
||||
{"Explain create table", "EXPLAIN SELECT * FROM test", &query.ExplainStmt{Statement: &query.SelectStmt{
|
||||
TableName: "test",
|
||||
ProjectionExprs: []expr.Expr{expr.Wildcard{}},
|
||||
}}, false},
|
||||
{"Multiple Explains", "EXPLAIN EXPLAIN CREATE TABLE test", nil, true},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/query"
|
||||
"github.com/genjidb/genji/internal/sql/scanner"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
"github.com/genjidb/genji/internal/stringutil"
|
||||
)
|
||||
|
||||
// parseInsertStatement parses an insert string and returns a Statement AST object.
|
||||
// This function assumes the INSERT token has already been consumed.
|
||||
func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
|
||||
var cfg insertConfig
|
||||
func (p *Parser) parseInsertStatement() (*query.InsertStmt, error) {
|
||||
var stmt query.InsertStmt
|
||||
var err error
|
||||
|
||||
// Parse "INTO".
|
||||
@@ -22,7 +19,7 @@ func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
|
||||
}
|
||||
|
||||
// Parse table name
|
||||
cfg.TableName, err = p.parseIdent()
|
||||
stmt.TableName, err = p.parseIdent()
|
||||
if err != nil {
|
||||
pErr := err.(*ParseError)
|
||||
pErr.Expected = []string{"table_name"}
|
||||
@@ -30,7 +27,7 @@ func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
|
||||
}
|
||||
|
||||
// Parse path list: (a, b, c)
|
||||
cfg.Fields, err = p.parseFieldList()
|
||||
stmt.Fields, err = p.parseFieldList()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -40,12 +37,12 @@ func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
|
||||
switch tok {
|
||||
case scanner.VALUES:
|
||||
// Parse VALUES (v1, v2, v3)
|
||||
cfg.Values, err = p.parseValues(cfg.Fields)
|
||||
stmt.Values, err = p.parseValues(stmt.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case scanner.SELECT:
|
||||
cfg.SelectStmt, err = p.parseSelectStatement()
|
||||
stmt.SelectStmt, err = p.parseSelectStatement()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -53,12 +50,12 @@ func (p *Parser) parseInsertStatement() (*query.StreamStmt, error) {
|
||||
return nil, newParseError(scanner.Tokstr(tok, lit), []string{"VALUES", "SELECT"}, pos)
|
||||
}
|
||||
|
||||
cfg.Returning, err = p.parseReturning()
|
||||
stmt.Returning, err = p.parseReturning()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg.ToStream()
|
||||
return &stmt, nil
|
||||
}
|
||||
|
||||
// parseFieldList parses a list of fields in the form: (path, path, ...), if exists.
|
||||
@@ -202,43 +199,3 @@ func (p *Parser) parseReturning() ([]expr.Expr, error) {
|
||||
|
||||
return p.parseProjectedExprs()
|
||||
}
|
||||
|
||||
// insertConfig holds INSERT configuration.
|
||||
type insertConfig struct {
|
||||
TableName string
|
||||
Values []expr.Expr
|
||||
Fields []string
|
||||
SelectStmt *query.StreamStmt
|
||||
Returning []expr.Expr
|
||||
}
|
||||
|
||||
func (cfg *insertConfig) ToStream() (*query.StreamStmt, error) {
|
||||
var s *stream.Stream
|
||||
if cfg.Values != nil {
|
||||
s = stream.New(stream.Expressions(cfg.Values...))
|
||||
|
||||
s = s.Pipe(stream.TableInsert(cfg.TableName))
|
||||
} else {
|
||||
s = cfg.SelectStmt.Stream
|
||||
|
||||
// ensure we are not reading and writing to the same table.
|
||||
if s.First().(*stream.SeqScanOperator).TableName == cfg.TableName {
|
||||
return nil, errors.New("cannot read and write to the same table")
|
||||
}
|
||||
|
||||
if len(cfg.Fields) > 0 {
|
||||
s = s.Pipe(stream.IterRename(cfg.Fields...))
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.TableInsert(cfg.TableName))
|
||||
}
|
||||
|
||||
if len(cfg.Returning) > 0 {
|
||||
s = s.Pipe(stream.Project(cfg.Returning...))
|
||||
}
|
||||
|
||||
return &query.StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -87,8 +87,6 @@ func TestParserInsert(t *testing.T) {
|
||||
nil, true},
|
||||
{"Values / Without fields / Wrong values", "INSERT INTO test VALUES {a: 1}, ('e', 'f')",
|
||||
nil, true},
|
||||
{"Select / same table", "INSERT INTO test SELECT * FROM test",
|
||||
nil, true},
|
||||
{"Select / Without fields", "INSERT INTO test SELECT * FROM foo",
|
||||
stream.New(stream.SeqScan("foo")).
|
||||
Pipe(stream.Project(expr.Wildcard{})).
|
||||
@@ -135,9 +133,11 @@ func TestParserInsert(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
stmt := q.Statements[0].(*query.StreamStmt)
|
||||
require.False(t, stmt.ReadOnly)
|
||||
require.EqualValues(t, test.expected.String(), stmt.Stream.String())
|
||||
stmt := q.Statements[0].(*query.InsertStmt)
|
||||
require.False(t, stmt.IsReadOnly())
|
||||
ss, err := stmt.ToStream()
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, test.expected.String(), ss.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/query"
|
||||
"github.com/genjidb/genji/internal/sql/parser"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
)
|
||||
|
||||
func TestParserMultiStatement(t *testing.T) {
|
||||
@@ -19,12 +18,12 @@ func TestParserMultiStatement(t *testing.T) {
|
||||
}{
|
||||
{"OnlyCommas", ";;;", nil},
|
||||
{"TrailingComma", "SELECT * FROM foo;;;DELETE FROM foo;", []query.Statement{
|
||||
&query.StreamStmt{
|
||||
Stream: stream.New(stream.SeqScan("foo")).Pipe(stream.Project(expr.Wildcard{})),
|
||||
ReadOnly: true,
|
||||
&query.SelectStmt{
|
||||
TableName: "foo",
|
||||
ProjectionExprs: []expr.Expr{expr.Wildcard{}},
|
||||
},
|
||||
&query.StreamStmt{
|
||||
Stream: stream.New(stream.SeqScan("foo")).Pipe(stream.TableDelete("foo")),
|
||||
&query.DeleteStmt{
|
||||
TableName: "foo",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
@@ -1,73 +1,69 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/query"
|
||||
"github.com/genjidb/genji/internal/sql/scanner"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
"github.com/genjidb/genji/internal/stringutil"
|
||||
)
|
||||
|
||||
// parseSelectStatement parses a select string and returns a Statement AST object.
|
||||
// This function assumes the SELECT token has already been consumed.
|
||||
func (p *Parser) parseSelectStatement() (*query.StreamStmt, error) {
|
||||
var cfg selectConfig
|
||||
func (p *Parser) parseSelectStatement() (*query.SelectStmt, error) {
|
||||
var stmt query.SelectStmt
|
||||
var err error
|
||||
|
||||
cfg.Distinct, err = p.parseDistinct()
|
||||
stmt.Distinct, err = p.parseDistinct()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse path list or query.Wildcard
|
||||
cfg.ProjectionExprs, err = p.parseProjectedExprs()
|
||||
stmt.ProjectionExprs, err = p.parseProjectedExprs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse "FROM".
|
||||
var found bool
|
||||
cfg.TableName, found, err = p.parseFrom()
|
||||
stmt.TableName, found, err = p.parseFrom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !found {
|
||||
return cfg.ToStream()
|
||||
return &stmt, nil
|
||||
}
|
||||
|
||||
// Parse condition: "WHERE expr".
|
||||
cfg.WhereExpr, err = p.parseCondition()
|
||||
stmt.WhereExpr, err = p.parseCondition()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse group by: "GROUP BY expr"
|
||||
cfg.GroupByExpr, err = p.parseGroupBy()
|
||||
stmt.GroupByExpr, err = p.parseGroupBy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse order by: "ORDER BY path [ASC|DESC]?"
|
||||
cfg.OrderBy, cfg.OrderByDirection, err = p.parseOrderBy()
|
||||
stmt.OrderBy, stmt.OrderByDirection, err = p.parseOrderBy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse limit: "LIMIT expr"
|
||||
cfg.LimitExpr, err = p.parseLimit()
|
||||
stmt.LimitExpr, err = p.parseLimit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse offset: "OFFSET expr"
|
||||
cfg.OffsetExpr, err = p.parseOffset()
|
||||
stmt.OffsetExpr, err = p.parseOffset()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg.ToStream()
|
||||
return &stmt, nil
|
||||
}
|
||||
|
||||
// parseProjectedExprs parses the list of projected fields.
|
||||
@@ -158,165 +154,3 @@ func (p *Parser) parseGroupBy() (expr.Expr, error) {
|
||||
e, err := p.ParseExpr()
|
||||
return e, err
|
||||
}
|
||||
|
||||
// SelectConfig holds SELECT configuration.
|
||||
type selectConfig struct {
|
||||
TableName string
|
||||
Distinct bool
|
||||
WhereExpr expr.Expr
|
||||
GroupByExpr expr.Expr
|
||||
OrderBy expr.Path
|
||||
OrderByDirection scanner.Token
|
||||
OffsetExpr expr.Expr
|
||||
LimitExpr expr.Expr
|
||||
ProjectionExprs []expr.Expr
|
||||
}
|
||||
|
||||
func (cfg selectConfig) ToStream() (*query.StreamStmt, error) {
|
||||
var s *stream.Stream
|
||||
|
||||
if cfg.TableName != "" {
|
||||
s = stream.New(stream.SeqScan(cfg.TableName))
|
||||
}
|
||||
|
||||
if cfg.WhereExpr != nil {
|
||||
s = s.Pipe(stream.Filter(cfg.WhereExpr))
|
||||
}
|
||||
|
||||
// when using GROUP BY, only aggregation functions or GroupByExpr can be selected
|
||||
if cfg.GroupByExpr != nil {
|
||||
// add Group node
|
||||
s = s.Pipe(stream.GroupBy(cfg.GroupByExpr))
|
||||
|
||||
var invalidProjectedField expr.Expr
|
||||
var aggregators []expr.AggregatorBuilder
|
||||
|
||||
for _, pe := range cfg.ProjectionExprs {
|
||||
ne, ok := pe.(*expr.NamedExpr)
|
||||
if !ok {
|
||||
invalidProjectedField = pe
|
||||
break
|
||||
}
|
||||
e := ne.Expr
|
||||
|
||||
// check if the projected expression is an aggregation function
|
||||
if agg, ok := e.(expr.AggregatorBuilder); ok {
|
||||
aggregators = append(aggregators, agg)
|
||||
continue
|
||||
}
|
||||
|
||||
// check if this is the same expression as the one used in the GROUP BY clause
|
||||
if expr.Equal(e, cfg.GroupByExpr) {
|
||||
continue
|
||||
}
|
||||
|
||||
// otherwise it's an error
|
||||
invalidProjectedField = ne
|
||||
break
|
||||
}
|
||||
|
||||
if invalidProjectedField != nil {
|
||||
return nil, stringutil.Errorf("field %q must appear in the GROUP BY clause or be used in an aggregate function", invalidProjectedField)
|
||||
}
|
||||
|
||||
// add Aggregation node
|
||||
s = s.Pipe(stream.HashAggregate(aggregators...))
|
||||
} else {
|
||||
// if there is no GROUP BY clause, check if there are any aggregation function
|
||||
// and if so add an aggregation node
|
||||
var aggregators []expr.AggregatorBuilder
|
||||
|
||||
for _, pe := range cfg.ProjectionExprs {
|
||||
ne, ok := pe.(*expr.NamedExpr)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
e := ne.Expr
|
||||
|
||||
// check if the projected expression is an aggregation function
|
||||
if agg, ok := e.(expr.AggregatorBuilder); ok {
|
||||
aggregators = append(aggregators, agg)
|
||||
}
|
||||
}
|
||||
|
||||
// add Aggregation node
|
||||
if len(aggregators) > 0 {
|
||||
s = s.Pipe(stream.HashAggregate(aggregators...))
|
||||
}
|
||||
}
|
||||
|
||||
// If there is no FROM clause ensure there is no wildcard or path
|
||||
if cfg.TableName == "" {
|
||||
var err error
|
||||
|
||||
for _, e := range cfg.ProjectionExprs {
|
||||
expr.Walk(e, func(e expr.Expr) bool {
|
||||
switch e.(type) {
|
||||
case expr.Path, expr.Wildcard:
|
||||
err = errors.New("no tables specified")
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Project(cfg.ProjectionExprs...))
|
||||
|
||||
if cfg.Distinct {
|
||||
s = s.Pipe(stream.Distinct())
|
||||
}
|
||||
|
||||
if cfg.OrderBy != nil {
|
||||
if cfg.OrderByDirection == scanner.DESC {
|
||||
s = s.Pipe(stream.SortReverse(cfg.OrderBy))
|
||||
} else {
|
||||
s = s.Pipe(stream.Sort(cfg.OrderBy))
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.OffsetExpr != nil {
|
||||
v, err := cfg.OffsetExpr.Eval(&expr.Environment{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !v.Type.IsNumber() {
|
||||
return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type)
|
||||
}
|
||||
|
||||
v, err = v.CastAsInteger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Skip(v.V.(int64)))
|
||||
}
|
||||
|
||||
if cfg.LimitExpr != nil {
|
||||
v, err := cfg.LimitExpr.Eval(&expr.Environment{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !v.Type.IsNumber() {
|
||||
return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type)
|
||||
}
|
||||
|
||||
v, err = v.CastAsInteger()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.Take(v.V.(int64)))
|
||||
}
|
||||
|
||||
return &query.StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -22,15 +22,6 @@ func TestParserSelect(t *testing.T) {
|
||||
stream.New(stream.Project(testutil.ParseNamedExpr(t, "1"))),
|
||||
false,
|
||||
},
|
||||
{"NoTable/path", "SELECT a",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{"NoTable/wildcard", "SELECT *",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{"Wildcard with no FORM", "SELECT *", nil, true},
|
||||
{"NoTableWithTuple", "SELECT (1, 2)",
|
||||
stream.New(stream.Project(testutil.ParseNamedExpr(t, "[1, 2]"))),
|
||||
false,
|
||||
@@ -84,8 +75,6 @@ func TestParserSelect(t *testing.T) {
|
||||
Pipe(stream.Project(testutil.ParseNamedExpr(t, "a.b.c"))),
|
||||
false,
|
||||
},
|
||||
{"With Invalid GroupBy: Wildcard", "SELECT * FROM test WHERE age = 10 GROUP BY a.b.c", nil, true},
|
||||
{"With Invalid GroupBy: a.b", "SELECT a.b FROM test WHERE age = 10 GROUP BY a.b.c", nil, true},
|
||||
{"WithOrderBy", "SELECT * FROM test WHERE age = 10 ORDER BY a.b.c",
|
||||
stream.New(stream.SeqScan("test")).
|
||||
Pipe(stream.Filter(parser.MustParseExpr("age = 10"))).
|
||||
@@ -135,11 +124,6 @@ func TestParserSelect(t *testing.T) {
|
||||
Pipe(stream.HashAggregate(&expr.CountFunc{Wildcard: true})).
|
||||
Pipe(stream.Project(testutil.ParseNamedExpr(t, "COUNT(*)"))),
|
||||
false},
|
||||
{"Invalid use of MIN() aggregator", "SELECT * FROM test LIMIT min(0)", nil, true},
|
||||
{"Invalid use of COUNT() aggregator", "SELECT * FROM test OFFSET x(*)", nil, true},
|
||||
{"Invalid use of MAX() aggregator", "SELECT * FROM test LIMIT max(0)", nil, true},
|
||||
{"Invalid use of SUM() aggregator", "SELECT * FROM test LIMIT sum(0)", nil, true},
|
||||
{"Invalid use of AVG() aggregator", "SELECT * FROM test LIMIT avg(0)", nil, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -148,7 +132,9 @@ func TestParserSelect(t *testing.T) {
|
||||
if !test.mustFail {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
require.EqualValues(t, &query.StreamStmt{Stream: test.expected, ReadOnly: true}, q.Statements[0])
|
||||
st, err := q.Statements[0].(*query.SelectStmt).ToStream()
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, &query.StreamStmt{Stream: test.expected, ReadOnly: true}, st)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"github.com/genjidb/genji/document"
|
||||
"github.com/genjidb/genji/internal/expr"
|
||||
"github.com/genjidb/genji/internal/query"
|
||||
"github.com/genjidb/genji/internal/sql/scanner"
|
||||
"github.com/genjidb/genji/internal/stream"
|
||||
)
|
||||
|
||||
// parseUpdateStatement parses a update string and returns a Statement AST object.
|
||||
// This function assumes the UPDATE token has already been consumed.
|
||||
func (p *Parser) parseUpdateStatement() (*query.StreamStmt, error) {
|
||||
var cfg updateConfig
|
||||
func (p *Parser) parseUpdateStatement() (*query.UpdateStmt, error) {
|
||||
var stmt query.UpdateStmt
|
||||
var err error
|
||||
|
||||
// Parse table name
|
||||
cfg.TableName, err = p.parseIdent()
|
||||
stmt.TableName, err = p.parseIdent()
|
||||
if err != nil {
|
||||
pErr := err.(*ParseError)
|
||||
pErr.Expected = []string{"table_name"}
|
||||
@@ -26,9 +23,9 @@ func (p *Parser) parseUpdateStatement() (*query.StreamStmt, error) {
|
||||
tok, pos, lit := p.ScanIgnoreWhitespace()
|
||||
switch tok {
|
||||
case scanner.SET:
|
||||
cfg.SetPairs, err = p.parseSetClause()
|
||||
stmt.SetPairs, err = p.parseSetClause()
|
||||
case scanner.UNSET:
|
||||
cfg.UnsetFields, err = p.parseUnsetClause()
|
||||
stmt.UnsetFields, err = p.parseUnsetClause()
|
||||
default:
|
||||
err = newParseError(scanner.Tokstr(tok, lit), []string{"SET", "UNSET"}, pos)
|
||||
}
|
||||
@@ -37,17 +34,17 @@ func (p *Parser) parseUpdateStatement() (*query.StreamStmt, error) {
|
||||
}
|
||||
|
||||
// Parse condition: "WHERE EXPR".
|
||||
cfg.WhereExpr, err = p.parseCondition()
|
||||
stmt.WhereExpr, err = p.parseCondition()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg.ToStream(), nil
|
||||
return &stmt, nil
|
||||
}
|
||||
|
||||
// parseSetClause parses the "SET" clause of the query.
|
||||
func (p *Parser) parseSetClause() ([]updateSetPair, error) {
|
||||
var pairs []updateSetPair
|
||||
func (p *Parser) parseSetClause() ([]query.UpdateSetPair, error) {
|
||||
var pairs []query.UpdateSetPair
|
||||
|
||||
firstPair := true
|
||||
for {
|
||||
@@ -78,7 +75,7 @@ func (p *Parser) parseSetClause() ([]updateSetPair, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pairs = append(pairs, updateSetPair{path, expr})
|
||||
pairs = append(pairs, query.UpdateSetPair{Path: path, E: expr})
|
||||
|
||||
firstPair = false
|
||||
}
|
||||
@@ -111,50 +108,3 @@ func (p *Parser) parseUnsetClause() ([]string, error) {
|
||||
}
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
// UpdateConfig holds UPDATE configuration.
|
||||
type updateConfig struct {
|
||||
TableName string
|
||||
|
||||
// SetPairs is used along with the Set clause. It holds
|
||||
// each path with its corresponding value that
|
||||
// should be set in the document.
|
||||
SetPairs []updateSetPair
|
||||
|
||||
// UnsetFields is used along with the Unset clause. It holds
|
||||
// each path that should be unset from the document.
|
||||
UnsetFields []string
|
||||
|
||||
WhereExpr expr.Expr
|
||||
}
|
||||
|
||||
type updateSetPair struct {
|
||||
path document.Path
|
||||
e expr.Expr
|
||||
}
|
||||
|
||||
// ToTree turns the statement into a stream.
|
||||
func (cfg updateConfig) ToStream() *query.StreamStmt {
|
||||
s := stream.New(stream.SeqScan(cfg.TableName))
|
||||
|
||||
if cfg.WhereExpr != nil {
|
||||
s = s.Pipe(stream.Filter(cfg.WhereExpr))
|
||||
}
|
||||
|
||||
if cfg.SetPairs != nil {
|
||||
for _, pair := range cfg.SetPairs {
|
||||
s = s.Pipe(stream.Set(pair.path, pair.e))
|
||||
}
|
||||
} else if cfg.UnsetFields != nil {
|
||||
for _, name := range cfg.UnsetFields {
|
||||
s = s.Pipe(stream.Unset(name))
|
||||
}
|
||||
}
|
||||
|
||||
s = s.Pipe(stream.TableReplace(cfg.TableName))
|
||||
|
||||
return &query.StreamStmt{
|
||||
Stream: s,
|
||||
ReadOnly: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,9 +75,9 @@ func TestParserUpdate(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, q.Statements, 1)
|
||||
stmt := q.Statements[0].(*query.StreamStmt)
|
||||
require.False(t, stmt.ReadOnly)
|
||||
require.EqualValues(t, test.expected, stmt.Stream)
|
||||
stmt := q.Statements[0].(*query.UpdateStmt)
|
||||
require.False(t, stmt.IsReadOnly())
|
||||
require.EqualValues(t, test.expected, stmt.ToStream().Stream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user