Files
chaisql/internal/query/statement/select.go
Jean Hadrien Chabran 4a6e68439a Refactor to handle errors with internal/errors (#432)
All new error handling code now relies on the internal/errors package,
which provides a compile-time toggle that enables capturing
stack traces for easier debugging while developing.

It also comes with a new testutil/assert package which replaces the require
package when it comes to checking or comparing errors and printing the
stack traces if needed.

Finally, the test target of the Makefile uses the debug build tag by default. 
A testnodebug target is also provided for convenience and to make sure no
tests are broken due to not having used the internal/errors or testutil/assert package.

See #431 for more details
2021-08-22 11:47:54 +03:00

197 lines
4.6 KiB
Go

package statement
import (
"github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/environment"
"github.com/genjidb/genji/internal/errors"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/sql/scanner"
"github.com/genjidb/genji/internal/stream"
"github.com/genjidb/genji/internal/stringutil"
)
// SelectStmt holds SELECT configuration.
// It describes a single SELECT statement; ToStream turns this description
// into a stream of operators.
type SelectStmt struct {
	// TableName is the table scanned by the statement. When empty, ToStream
	// treats the statement as having no FROM clause.
	TableName string
	// Distinct adds a Distinct operator after the projection when true.
	Distinct bool
	// WhereExpr, when non-nil, is used to filter the stream.
	WhereExpr expr.Expr
	// GroupByExpr, when non-nil, groups the stream; projections must then be
	// aggregators or equal to this expression.
	GroupByExpr expr.Expr
	// OrderBy, when non-nil, is the path the output is sorted by.
	OrderBy expr.Path
	// OrderByDirection selects a reverse sort when equal to scanner.DESC;
	// any other value yields a regular sort.
	OrderByDirection scanner.Token
	// OffsetExpr, when non-nil, is evaluated against an empty environment by
	// ToStream and must yield a number.
	OffsetExpr expr.Expr
	// LimitExpr, when non-nil, is evaluated against an empty environment by
	// ToStream and must yield a number.
	LimitExpr expr.Expr
	// ProjectionExprs are the expressions projected by the statement.
	ProjectionExprs []expr.Expr
	// Union describes an optional statement combined with this one.
	Union struct {
		// NOTE(review): All is not read by ToStream; presumably it
		// distinguishes UNION ALL from UNION — confirm where it is consumed.
		All bool
		// SelectStmt, when non-nil, is the statement whose stream is
		// concatenated after this one.
		SelectStmt *StreamStmt
	}
}
// ToStream translates the SELECT statement into a stream of operators and
// wraps it in a StreamStmt.
//
// Operators are piped in this order: table scan (when a table is named),
// WHERE filter, GROUP BY and aggregation, projection, DISTINCT, ORDER BY,
// OFFSET, LIMIT and finally concatenation with the UNION operand, if any.
//
// It returns an error if:
//   - a GROUP BY clause is present and a projected expression is neither an
//     aggregator nor equal to the GROUP BY expression,
//   - there is no FROM clause but a projection contains a path or a wildcard,
//   - the OFFSET or LIMIT expression fails to evaluate, does not evaluate to
//     a number, or cannot be cast to an integer.
//
// The returned statement is read-only unless one of the projections uses
// NEXT VALUE FOR, or the UNION operand is itself not read-only.
func (stmt *SelectStmt) ToStream() (*StreamStmt, error) {
	isReadOnly := true

	// NOTE(review): s stays nil when there is no FROM clause, so every Pipe
	// call below is then made on a nil *stream.Stream. Presumably Pipe
	// tolerates a nil receiver — confirm against the stream package.
	var s *stream.Stream

	if stmt.TableName != "" {
		s = stream.New(stream.SeqScan(stmt.TableName))
	}

	if stmt.WhereExpr != nil {
		s = s.Pipe(stream.Filter(stmt.WhereExpr))
	}

	// when using GROUP BY, only aggregation functions or GroupByExpr can be selected
	if stmt.GroupByExpr != nil {
		// add Group node
		s = s.Pipe(stream.GroupBy(stmt.GroupByExpr))

		var invalidProjectedField expr.Expr
		var aggregators []expr.AggregatorBuilder

		for _, pe := range stmt.ProjectionExprs {
			ne, ok := pe.(*expr.NamedExpr)
			if !ok {
				invalidProjectedField = pe
				break
			}
			e := ne.Expr

			// check if the projected expression is an aggregation function
			if agg, ok := e.(expr.AggregatorBuilder); ok {
				aggregators = append(aggregators, agg)
				continue
			}

			// check if this is the same expression as the one used in the GROUP BY clause
			if expr.Equal(e, stmt.GroupByExpr) {
				continue
			}

			// otherwise it's an error
			invalidProjectedField = ne
			break
		}

		if invalidProjectedField != nil {
			return nil, stringutil.Errorf("field %q must appear in the GROUP BY clause or be used in an aggregate function", invalidProjectedField)
		}

		// add Aggregation node
		s = s.Pipe(stream.HashAggregate(aggregators...))
	} else {
		// if there is no GROUP BY clause, check if there are any aggregation function
		// and if so add an aggregation node
		var aggregators []expr.AggregatorBuilder

		for _, pe := range stmt.ProjectionExprs {
			ne, ok := pe.(*expr.NamedExpr)
			if !ok {
				continue
			}

			// check if the projected expression is an aggregation function
			if agg, ok := ne.Expr.(expr.AggregatorBuilder); ok {
				aggregators = append(aggregators, agg)
			}
		}

		// add Aggregation node
		if len(aggregators) > 0 {
			s = s.Pipe(stream.HashAggregate(aggregators...))
		}
	}

	// If there is no FROM clause ensure there is no wildcard or path
	if stmt.TableName == "" {
		var err error

		for _, e := range stmt.ProjectionExprs {
			expr.Walk(e, func(e expr.Expr) bool {
				switch e.(type) {
				case expr.Path, expr.Wildcard:
					err = errors.New("no tables specified")
					return false
				default:
					return true
				}
			})
			if err != nil {
				return nil, err
			}
		}
	}

	s = s.Pipe(stream.Project(stmt.ProjectionExprs...))

	if stmt.Distinct {
		s = s.Pipe(stream.Distinct())
	}

	if stmt.OrderBy != nil {
		if stmt.OrderByDirection == scanner.DESC {
			s = s.Pipe(stream.SortReverse(stmt.OrderBy))
		} else {
			s = s.Pipe(stream.Sort(stmt.OrderBy))
		}
	}

	// OFFSET is evaluated once, here, against an empty environment.
	if stmt.OffsetExpr != nil {
		v, err := stmt.OffsetExpr.Eval(&environment.Environment{})
		if err != nil {
			return nil, err
		}

		if !v.Type().IsNumber() {
			return nil, stringutil.Errorf("offset expression must evaluate to a number, got %q", v.Type())
		}

		v, err = document.CastAsInteger(v)
		if err != nil {
			return nil, err
		}

		s = s.Pipe(stream.Skip(v.V().(int64)))
	}

	// LIMIT is evaluated once, here, against an empty environment.
	if stmt.LimitExpr != nil {
		v, err := stmt.LimitExpr.Eval(&environment.Environment{})
		if err != nil {
			return nil, err
		}

		if !v.Type().IsNumber() {
			return nil, stringutil.Errorf("limit expression must evaluate to a number, got %q", v.Type())
		}

		v, err = document.CastAsInteger(v)
		if err != nil {
			return nil, err
		}

		s = s.Pipe(stream.Take(v.V().(int64)))
	}

	if stmt.Union.SelectStmt != nil {
		s = stream.New(stream.Concat(s, stmt.Union.SelectStmt.Stream))

		// The combined statement can only be read-only if the UNION operand
		// is read-only as well.
		if !stmt.Union.SelectStmt.ReadOnly {
			isReadOnly = false
		}
	}

	// SELECT is read-only most of the time, unless it's using some expressions
	// that require write access and that are allowed to be run, such as NEXT VALUE FOR
	for _, e := range stmt.ProjectionExprs {
		expr.Walk(e, func(e expr.Expr) bool {
			switch e.(type) {
			case expr.NextValueFor:
				isReadOnly = false
				return false
			default:
				return true
			}
		})
	}

	return &StreamStmt{
		Stream:   s,
		ReadOnly: isReadOnly,
	}, nil
}