expr: allow wildcard as an argument for all functions

This commit is contained in:
Asdine El Hrychy
2023-12-03 09:16:33 +04:00
parent edcb91d4c5
commit 31dbf28801
12 changed files with 52 additions and 39 deletions

View File

@@ -85,7 +85,7 @@ func newQueryInputModel(shell *Shell) queryInputModel {
ta.Cursor.SetMode(cursor.CursorStatic) ta.Cursor.SetMode(cursor.CursorStatic)
ta.MaxWidth = 0 ta.MaxWidth = 0
ta.SetHeight(1) ta.SetHeight(1)
ta.SetPromptFunc(7, func(lineIdx int) string { ta.SetPromptFunc(6, func(lineIdx int) string {
if lineIdx == 0 { if lineIdx == 0 {
return "chai> " return "chai> "
} }

View File

@@ -103,7 +103,7 @@ func (e *NamedExpr) Name() string {
} }
func (e *NamedExpr) String() string { func (e *NamedExpr) String() string {
return fmt.Sprintf("%s", e.Expr) return e.Expr.String()
} }
// A Function is an expression whose evaluation calls a function previously defined. // A Function is an expression whose evaluation calls a function previously defined.

View File

@@ -30,7 +30,7 @@ var builtinFunctions = Definitions{
name: "count", name: "count",
arity: 1, arity: 1,
constructorFn: func(args ...expr.Expr) (expr.Function, error) { constructorFn: func(args ...expr.Expr) (expr.Function, error) {
return &Count{Expr: args[0]}, nil return NewCount(args[0]), nil
}, },
}, },
"min": &definition{ "min": &definition{
@@ -189,10 +189,18 @@ var _ expr.AggregatorBuilder = (*Count)(nil)
// in a stream. // in a stream.
type Count struct { type Count struct {
Expr expr.Expr Expr expr.Expr
Wildcard bool wildcard bool
Count int64 Count int64
} }
func NewCount(e expr.Expr) *Count {
_, wc := e.(expr.Wildcard)
return &Count{
Expr: e,
wildcard: wc,
}
}
func (c *Count) Eval(env *environment.Environment) (types.Value, error) { func (c *Count) Eval(env *environment.Environment) (types.Value, error) {
d, ok := env.GetRow() d, ok := env.GetRow()
if !ok { if !ok {
@@ -214,20 +222,12 @@ func (c *Count) IsEqual(other expr.Expr) bool {
return false return false
} }
if c.Wildcard && o.Wildcard {
return c.Expr == nil && o.Expr == nil
}
return expr.Equal(c.Expr, o.Expr) return expr.Equal(c.Expr, o.Expr)
} }
func (c *Count) Params() []expr.Expr { return []expr.Expr{c.Expr} } func (c *Count) Params() []expr.Expr { return []expr.Expr{c.Expr} }
func (c *Count) String() string { func (c *Count) String() string {
if c.Wildcard {
return "COUNT(*)"
}
return fmt.Sprintf("COUNT(%v)", c.Expr) return fmt.Sprintf("COUNT(%v)", c.Expr)
} }
@@ -246,7 +246,7 @@ type CountAggregator struct {
// Aggregate increments the counter if the count expression evaluates to a non-null value. // Aggregate increments the counter if the count expression evaluates to a non-null value.
func (c *CountAggregator) Aggregate(env *environment.Environment) error { func (c *CountAggregator) Aggregate(env *environment.Environment) error {
if c.Fn.Wildcard { if c.Fn.wildcard {
c.Count++ c.Count++
return nil return nil
} }

View File

@@ -169,7 +169,7 @@ func (kvp *KVPairs) String() string {
if i > 0 { if i > 0 {
b.WriteString(", ") b.WriteString(", ")
} }
b.WriteString(fmt.Sprintf("%s", p)) b.WriteString(p.String())
} }
b.WriteRune('}') b.WriteRune('}')

View File

@@ -63,15 +63,20 @@ func (w Wildcard) String() string {
} }
func (w Wildcard) Eval(env *environment.Environment) (types.Value, error) { func (w Wildcard) Eval(env *environment.Environment) (types.Value, error) {
return nil, errors.New("no table specified") r, ok := env.GetRow()
if !ok {
return nil, errors.New("no table specified")
}
return types.NewObjectValue(r.Object()), nil
} }
// Iterate call the object iterate method. // Iterate call the object iterate method.
func (w Wildcard) Iterate(env environment.Environment, fn func(field string, value types.Value) error) error { func (w Wildcard) Iterate(env environment.Environment, fn func(field string, value types.Value) error) error {
d, ok := env.GetRow() r, ok := env.GetRow()
if !ok { if !ok {
return errors.New("no table specified") return errors.New("no table specified")
} }
return d.Iterate(fn) return r.Iterate(fn)
} }

View File

@@ -355,3 +355,7 @@ func NewFromCSV(headers, columns []string) types.Object {
return fb return fb
} }
func NewArrayFromSlice[T any](l []T) types.Array {
return &sliceArray{ref: reflect.ValueOf(l)}
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/environment"
"github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/expr"
"github.com/chaisql/chai/internal/expr/functions"
"github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/object"
"github.com/chaisql/chai/internal/sql/scanner" "github.com/chaisql/chai/internal/sql/scanner"
"github.com/chaisql/chai/internal/types" "github.com/chaisql/chai/internal/types"
@@ -292,6 +291,8 @@ func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) {
return expr.LiteralValue{Value: types.NewBoolValue(tok == scanner.TRUE)}, nil return expr.LiteralValue{Value: types.NewBoolValue(tok == scanner.TRUE)}, nil
case scanner.NULL: case scanner.NULL:
return expr.LiteralValue{Value: types.NewNullValue()}, nil return expr.LiteralValue{Value: types.NewNullValue()}, nil
case scanner.MUL:
return expr.Wildcard{}, nil
case scanner.LBRACKET: case scanner.LBRACKET:
p.Unscan() p.Unscan()
e, err := p.ParseObject() e, err := p.ParseObject()
@@ -669,16 +670,6 @@ func (p *Parser) parseFunction() (expr.Expr, error) {
return nil, err return nil, err
} }
// Special case: If the function is COUNT, support the special case COUNT(*)
if tok, pos, lit := p.ScanIgnoreWhitespace(); tok == scanner.MUL {
if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.RPAREN {
return nil, newParseError(scanner.Tokstr(tok, lit), []string{")"}, pos)
}
return &functions.Count{Wildcard: true}, nil
}
p.Unscan()
// Check if the function is called without arguments. // Check if the function is called without arguments.
if tok, _, _ := p.ScanIgnoreWhitespace(); tok == scanner.RPAREN { if tok, _, _ := p.ScanIgnoreWhitespace(); tok == scanner.RPAREN {
def, err := p.packagesTable.GetFunc(pkgName, funcName) def, err := p.packagesTable.GetFunc(pkgName, funcName)

View File

@@ -170,8 +170,8 @@ func TestParserExpr(t *testing.T) {
// functions // functions
{"pk() function", "pk()", &functions.PK{}, false}, {"pk() function", "pk()", &functions.PK{}, false},
{"count(expr) function", "count(a)", &functions.Count{Expr: testutil.ParsePath(t, "a")}, false}, {"count(expr) function", "count(a)", &functions.Count{Expr: testutil.ParsePath(t, "a")}, false},
{"count(*) function", "count(*)", &functions.Count{Wildcard: true}, false}, {"count(*) function", "count(*)", functions.NewCount(expr.Wildcard{}), false},
{"count (*) function with spaces", "count (*)", &functions.Count{Wildcard: true}, false}, {"count (*) function with spaces", "count (*)", functions.NewCount(expr.Wildcard{}), false},
{"packaged function", "math.floor(1.2)", testutil.FunctionExpr(t, "math.floor", testutil.DoubleValue(1.2)), false}, {"packaged function", "math.floor(1.2)", testutil.FunctionExpr(t, "math.floor", testutil.DoubleValue(1.2)), false},
} }

View File

@@ -126,7 +126,7 @@ func TestParserSelect(t *testing.T) {
{"WithOffsetThenLimit", "SELECT * FROM test WHERE age = 10 OFFSET 20 LIMIT 10", nil, true, true}, {"WithOffsetThenLimit", "SELECT * FROM test WHERE age = 10 OFFSET 20 LIMIT 10", nil, true, true},
{"With aggregation function", "SELECT COUNT(*) FROM test", {"With aggregation function", "SELECT COUNT(*) FROM test",
stream.New(table.Scan("test")). stream.New(table.Scan("test")).
Pipe(rows.GroupAggregate(nil, &functions.Count{Wildcard: true})). Pipe(rows.GroupAggregate(nil, functions.NewCount(expr.Wildcard{}))).
Pipe(rows.Project(testutil.ParseNamedExpr(t, "COUNT(*)"))), Pipe(rows.Project(testutil.ParseNamedExpr(t, "COUNT(*)"))),
true, false}, true, false},
{"With NEXT VALUE FOR", "SELECT NEXT VALUE FOR foo FROM test", {"With NEXT VALUE FOR", "SELECT NEXT VALUE FOR foo FROM test",

View File

@@ -38,7 +38,7 @@ func TestAggregate(t *testing.T) {
{ {
"count", "count",
nil, nil,
[]expr.AggregatorBuilder{&functions.Count{Wildcard: true}}, []expr.AggregatorBuilder{functions.NewCount(expr.Wildcard{})},
[]types.Object{testutil.MakeObject(t, `{"a": 10}`)}, []types.Object{testutil.MakeObject(t, `{"a": 10}`)},
[]types.Object{testutil.MakeObject(t, `{"COUNT(*)": 1}`)}, []types.Object{testutil.MakeObject(t, `{"COUNT(*)": 1}`)},
false, false,

View File

@@ -172,22 +172,22 @@ func ExprRunner(t *testing.T, testfile string) {
t.Helper() t.Helper()
// parse the expected result // parse the expected result
e, err := parser.NewParser(strings.NewReader(stmt.Res)).ParseExpr() e, err := parser.NewParser(strings.NewReader(stmt.Res)).ParseExpr()
assert.NoErrorf(t, err, "parse error at %s:%d\n`%s`", testfile, stmt.ResLine, stmt.Res) assert.NoErrorf(t, err, "parse error at %s:%d\n`%s`: %v", testfile, stmt.ResLine, stmt.Res, err)
// eval it to get a proper Value // eval it to get a proper Value
want, err := e.Eval(environment.New(nil)) want, err := e.Eval(environment.New(nil))
assert.NoErrorf(t, err, "eval error at %s:%d\n`%s`", testfile, stmt.ResLine, stmt.Res) assert.NoErrorf(t, err, "eval error at %s:%d\n`%s`: %v", testfile, stmt.ResLine, stmt.Res, err)
// parse the given expr // parse the given expr
e, err = parser.NewParser(strings.NewReader(stmt.Expr)).ParseExpr() e, err = parser.NewParser(strings.NewReader(stmt.Expr)).ParseExpr()
assert.NoErrorf(t, err, "parse error at %s:%d\n`%s`", testfile, stmt.ExprLine, stmt.Expr) assert.NoErrorf(t, err, "parse error at %s:%d\n`%s`: %v", testfile, stmt.ExprLine, stmt.Expr, err)
// eval it to get a proper Value // eval it to get a proper Value
got, err := e.Eval(environment.New(nil)) got, err := e.Eval(environment.New(nil))
assert.NoErrorf(t, err, "eval error at %s:%d\n`%s`", testfile, stmt.ExprLine, stmt.Expr) assert.NoErrorf(t, err, "eval error at %s:%d\n`%s`: %v", testfile, stmt.ExprLine, stmt.Expr, err)
// finally, compare those two // finally, compare those two
require.Equalf(t, want, got, "assertion error at %s:%d", testfile, stmt.ResLine) RequireValueEqual(t, want, got, "assertion error at %s:%d", testfile, stmt.ResLine)
}) })
} else { } else {
t.Run("NOK "+stmt.Expr, func(t *testing.T) { t.Run("NOK "+stmt.Expr, func(t *testing.T) {

View File

@@ -18,7 +18,7 @@ import (
) )
// MakeValue turns v into a types.Value. // MakeValue turns v into a types.Value.
func MakeValue(t testing.TB, v interface{}) types.Value { func MakeValue(t testing.TB, v any) types.Value {
t.Helper() t.Helper()
vv, err := object.NewValue(v) vv, err := object.NewValue(v)
@@ -26,7 +26,7 @@ func MakeValue(t testing.TB, v interface{}) types.Value {
return vv return vv
} }
func MakeArrayValue(t testing.TB, vs ...interface{}) types.Value { func MakeArrayValue(t testing.TB, vs ...any) types.Value {
t.Helper() t.Helper()
vvs := []types.Value{} vvs := []types.Value{}
@@ -168,6 +168,19 @@ func RequireArrayEqual(t testing.TB, want, got types.Array) {
} }
} }
func RequireValueEqual(t testing.TB, want, got types.Value, msg string, args ...any) {
t.Helper()
tWant, err := types.MarshalTextIndent(want, "\n", " ")
require.NoError(t, err)
tGot, err := types.MarshalTextIndent(got, "\n", " ")
require.NoError(t, err)
if diff := cmp.Diff(string(tWant), string(tGot), cmp.Comparer(strings.EqualFold)); diff != "" {
require.Failf(t, "mismatched values, (-want, +got)", "%s\n%s", diff, fmt.Sprintf(msg, args...))
}
}
func CloneObject(t testing.TB, d types.Object) *object.FieldBuffer { func CloneObject(t testing.TB, d types.Object) *object.FieldBuffer {
t.Helper() t.Helper()