mirror of
https://github.com/chaisql/chai.git
synced 2025-11-02 11:44:02 +08:00
expr: allow wildcard as an argument for all functions
This commit is contained in:
@@ -85,7 +85,7 @@ func newQueryInputModel(shell *Shell) queryInputModel {
|
|||||||
ta.Cursor.SetMode(cursor.CursorStatic)
|
ta.Cursor.SetMode(cursor.CursorStatic)
|
||||||
ta.MaxWidth = 0
|
ta.MaxWidth = 0
|
||||||
ta.SetHeight(1)
|
ta.SetHeight(1)
|
||||||
ta.SetPromptFunc(7, func(lineIdx int) string {
|
ta.SetPromptFunc(6, func(lineIdx int) string {
|
||||||
if lineIdx == 0 {
|
if lineIdx == 0 {
|
||||||
return "chai> "
|
return "chai> "
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ func (e *NamedExpr) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *NamedExpr) String() string {
|
func (e *NamedExpr) String() string {
|
||||||
return fmt.Sprintf("%s", e.Expr)
|
return e.Expr.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Function is an expression whose evaluation calls a function previously defined.
|
// A Function is an expression whose evaluation calls a function previously defined.
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ var builtinFunctions = Definitions{
|
|||||||
name: "count",
|
name: "count",
|
||||||
arity: 1,
|
arity: 1,
|
||||||
constructorFn: func(args ...expr.Expr) (expr.Function, error) {
|
constructorFn: func(args ...expr.Expr) (expr.Function, error) {
|
||||||
return &Count{Expr: args[0]}, nil
|
return NewCount(args[0]), nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"min": &definition{
|
"min": &definition{
|
||||||
@@ -189,10 +189,18 @@ var _ expr.AggregatorBuilder = (*Count)(nil)
|
|||||||
// in a stream.
|
// in a stream.
|
||||||
type Count struct {
|
type Count struct {
|
||||||
Expr expr.Expr
|
Expr expr.Expr
|
||||||
Wildcard bool
|
wildcard bool
|
||||||
Count int64
|
Count int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewCount(e expr.Expr) *Count {
|
||||||
|
_, wc := e.(expr.Wildcard)
|
||||||
|
return &Count{
|
||||||
|
Expr: e,
|
||||||
|
wildcard: wc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Count) Eval(env *environment.Environment) (types.Value, error) {
|
func (c *Count) Eval(env *environment.Environment) (types.Value, error) {
|
||||||
d, ok := env.GetRow()
|
d, ok := env.GetRow()
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -214,20 +222,12 @@ func (c *Count) IsEqual(other expr.Expr) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Wildcard && o.Wildcard {
|
|
||||||
return c.Expr == nil && o.Expr == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return expr.Equal(c.Expr, o.Expr)
|
return expr.Equal(c.Expr, o.Expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Count) Params() []expr.Expr { return []expr.Expr{c.Expr} }
|
func (c *Count) Params() []expr.Expr { return []expr.Expr{c.Expr} }
|
||||||
|
|
||||||
func (c *Count) String() string {
|
func (c *Count) String() string {
|
||||||
if c.Wildcard {
|
|
||||||
return "COUNT(*)"
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("COUNT(%v)", c.Expr)
|
return fmt.Sprintf("COUNT(%v)", c.Expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,7 +246,7 @@ type CountAggregator struct {
|
|||||||
|
|
||||||
// Aggregate increments the counter if the count expression evaluates to a non-null value.
|
// Aggregate increments the counter if the count expression evaluates to a non-null value.
|
||||||
func (c *CountAggregator) Aggregate(env *environment.Environment) error {
|
func (c *CountAggregator) Aggregate(env *environment.Environment) error {
|
||||||
if c.Fn.Wildcard {
|
if c.Fn.wildcard {
|
||||||
c.Count++
|
c.Count++
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ func (kvp *KVPairs) String() string {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
b.WriteString(", ")
|
b.WriteString(", ")
|
||||||
}
|
}
|
||||||
b.WriteString(fmt.Sprintf("%s", p))
|
b.WriteString(p.String())
|
||||||
}
|
}
|
||||||
b.WriteRune('}')
|
b.WriteRune('}')
|
||||||
|
|
||||||
|
|||||||
@@ -63,15 +63,20 @@ func (w Wildcard) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w Wildcard) Eval(env *environment.Environment) (types.Value, error) {
|
func (w Wildcard) Eval(env *environment.Environment) (types.Value, error) {
|
||||||
return nil, errors.New("no table specified")
|
r, ok := env.GetRow()
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("no table specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
return types.NewObjectValue(r.Object()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterate call the object iterate method.
|
// Iterate call the object iterate method.
|
||||||
func (w Wildcard) Iterate(env environment.Environment, fn func(field string, value types.Value) error) error {
|
func (w Wildcard) Iterate(env environment.Environment, fn func(field string, value types.Value) error) error {
|
||||||
d, ok := env.GetRow()
|
r, ok := env.GetRow()
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("no table specified")
|
return errors.New("no table specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.Iterate(fn)
|
return r.Iterate(fn)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -355,3 +355,7 @@ func NewFromCSV(headers, columns []string) types.Object {
|
|||||||
|
|
||||||
return fb
|
return fb
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewArrayFromSlice[T any](l []T) types.Array {
|
||||||
|
return &sliceArray{ref: reflect.ValueOf(l)}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
|
|
||||||
"github.com/chaisql/chai/internal/environment"
|
"github.com/chaisql/chai/internal/environment"
|
||||||
"github.com/chaisql/chai/internal/expr"
|
"github.com/chaisql/chai/internal/expr"
|
||||||
"github.com/chaisql/chai/internal/expr/functions"
|
|
||||||
"github.com/chaisql/chai/internal/object"
|
"github.com/chaisql/chai/internal/object"
|
||||||
"github.com/chaisql/chai/internal/sql/scanner"
|
"github.com/chaisql/chai/internal/sql/scanner"
|
||||||
"github.com/chaisql/chai/internal/types"
|
"github.com/chaisql/chai/internal/types"
|
||||||
@@ -292,6 +291,8 @@ func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) {
|
|||||||
return expr.LiteralValue{Value: types.NewBoolValue(tok == scanner.TRUE)}, nil
|
return expr.LiteralValue{Value: types.NewBoolValue(tok == scanner.TRUE)}, nil
|
||||||
case scanner.NULL:
|
case scanner.NULL:
|
||||||
return expr.LiteralValue{Value: types.NewNullValue()}, nil
|
return expr.LiteralValue{Value: types.NewNullValue()}, nil
|
||||||
|
case scanner.MUL:
|
||||||
|
return expr.Wildcard{}, nil
|
||||||
case scanner.LBRACKET:
|
case scanner.LBRACKET:
|
||||||
p.Unscan()
|
p.Unscan()
|
||||||
e, err := p.ParseObject()
|
e, err := p.ParseObject()
|
||||||
@@ -669,16 +670,6 @@ func (p *Parser) parseFunction() (expr.Expr, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case: If the function is COUNT, support the special case COUNT(*)
|
|
||||||
if tok, pos, lit := p.ScanIgnoreWhitespace(); tok == scanner.MUL {
|
|
||||||
if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.RPAREN {
|
|
||||||
return nil, newParseError(scanner.Tokstr(tok, lit), []string{")"}, pos)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &functions.Count{Wildcard: true}, nil
|
|
||||||
}
|
|
||||||
p.Unscan()
|
|
||||||
|
|
||||||
// Check if the function is called without arguments.
|
// Check if the function is called without arguments.
|
||||||
if tok, _, _ := p.ScanIgnoreWhitespace(); tok == scanner.RPAREN {
|
if tok, _, _ := p.ScanIgnoreWhitespace(); tok == scanner.RPAREN {
|
||||||
def, err := p.packagesTable.GetFunc(pkgName, funcName)
|
def, err := p.packagesTable.GetFunc(pkgName, funcName)
|
||||||
|
|||||||
@@ -170,8 +170,8 @@ func TestParserExpr(t *testing.T) {
|
|||||||
// functions
|
// functions
|
||||||
{"pk() function", "pk()", &functions.PK{}, false},
|
{"pk() function", "pk()", &functions.PK{}, false},
|
||||||
{"count(expr) function", "count(a)", &functions.Count{Expr: testutil.ParsePath(t, "a")}, false},
|
{"count(expr) function", "count(a)", &functions.Count{Expr: testutil.ParsePath(t, "a")}, false},
|
||||||
{"count(*) function", "count(*)", &functions.Count{Wildcard: true}, false},
|
{"count(*) function", "count(*)", functions.NewCount(expr.Wildcard{}), false},
|
||||||
{"count (*) function with spaces", "count (*)", &functions.Count{Wildcard: true}, false},
|
{"count (*) function with spaces", "count (*)", functions.NewCount(expr.Wildcard{}), false},
|
||||||
{"packaged function", "math.floor(1.2)", testutil.FunctionExpr(t, "math.floor", testutil.DoubleValue(1.2)), false},
|
{"packaged function", "math.floor(1.2)", testutil.FunctionExpr(t, "math.floor", testutil.DoubleValue(1.2)), false},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ func TestParserSelect(t *testing.T) {
|
|||||||
{"WithOffsetThenLimit", "SELECT * FROM test WHERE age = 10 OFFSET 20 LIMIT 10", nil, true, true},
|
{"WithOffsetThenLimit", "SELECT * FROM test WHERE age = 10 OFFSET 20 LIMIT 10", nil, true, true},
|
||||||
{"With aggregation function", "SELECT COUNT(*) FROM test",
|
{"With aggregation function", "SELECT COUNT(*) FROM test",
|
||||||
stream.New(table.Scan("test")).
|
stream.New(table.Scan("test")).
|
||||||
Pipe(rows.GroupAggregate(nil, &functions.Count{Wildcard: true})).
|
Pipe(rows.GroupAggregate(nil, functions.NewCount(expr.Wildcard{}))).
|
||||||
Pipe(rows.Project(testutil.ParseNamedExpr(t, "COUNT(*)"))),
|
Pipe(rows.Project(testutil.ParseNamedExpr(t, "COUNT(*)"))),
|
||||||
true, false},
|
true, false},
|
||||||
{"With NEXT VALUE FOR", "SELECT NEXT VALUE FOR foo FROM test",
|
{"With NEXT VALUE FOR", "SELECT NEXT VALUE FOR foo FROM test",
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func TestAggregate(t *testing.T) {
|
|||||||
{
|
{
|
||||||
"count",
|
"count",
|
||||||
nil,
|
nil,
|
||||||
[]expr.AggregatorBuilder{&functions.Count{Wildcard: true}},
|
[]expr.AggregatorBuilder{functions.NewCount(expr.Wildcard{})},
|
||||||
[]types.Object{testutil.MakeObject(t, `{"a": 10}`)},
|
[]types.Object{testutil.MakeObject(t, `{"a": 10}`)},
|
||||||
[]types.Object{testutil.MakeObject(t, `{"COUNT(*)": 1}`)},
|
[]types.Object{testutil.MakeObject(t, `{"COUNT(*)": 1}`)},
|
||||||
false,
|
false,
|
||||||
|
|||||||
@@ -172,22 +172,22 @@ func ExprRunner(t *testing.T, testfile string) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
// parse the expected result
|
// parse the expected result
|
||||||
e, err := parser.NewParser(strings.NewReader(stmt.Res)).ParseExpr()
|
e, err := parser.NewParser(strings.NewReader(stmt.Res)).ParseExpr()
|
||||||
assert.NoErrorf(t, err, "parse error at %s:%d\n`%s`", testfile, stmt.ResLine, stmt.Res)
|
assert.NoErrorf(t, err, "parse error at %s:%d\n`%s`: %v", testfile, stmt.ResLine, stmt.Res, err)
|
||||||
|
|
||||||
// eval it to get a proper Value
|
// eval it to get a proper Value
|
||||||
want, err := e.Eval(environment.New(nil))
|
want, err := e.Eval(environment.New(nil))
|
||||||
assert.NoErrorf(t, err, "eval error at %s:%d\n`%s`", testfile, stmt.ResLine, stmt.Res)
|
assert.NoErrorf(t, err, "eval error at %s:%d\n`%s`: %v", testfile, stmt.ResLine, stmt.Res, err)
|
||||||
|
|
||||||
// parse the given expr
|
// parse the given expr
|
||||||
e, err = parser.NewParser(strings.NewReader(stmt.Expr)).ParseExpr()
|
e, err = parser.NewParser(strings.NewReader(stmt.Expr)).ParseExpr()
|
||||||
assert.NoErrorf(t, err, "parse error at %s:%d\n`%s`", testfile, stmt.ExprLine, stmt.Expr)
|
assert.NoErrorf(t, err, "parse error at %s:%d\n`%s`: %v", testfile, stmt.ExprLine, stmt.Expr, err)
|
||||||
|
|
||||||
// eval it to get a proper Value
|
// eval it to get a proper Value
|
||||||
got, err := e.Eval(environment.New(nil))
|
got, err := e.Eval(environment.New(nil))
|
||||||
assert.NoErrorf(t, err, "eval error at %s:%d\n`%s`", testfile, stmt.ExprLine, stmt.Expr)
|
assert.NoErrorf(t, err, "eval error at %s:%d\n`%s`: %v", testfile, stmt.ExprLine, stmt.Expr, err)
|
||||||
|
|
||||||
// finally, compare those two
|
// finally, compare those two
|
||||||
require.Equalf(t, want, got, "assertion error at %s:%d", testfile, stmt.ResLine)
|
RequireValueEqual(t, want, got, "assertion error at %s:%d", testfile, stmt.ResLine)
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
t.Run("NOK "+stmt.Expr, func(t *testing.T) {
|
t.Run("NOK "+stmt.Expr, func(t *testing.T) {
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// MakeValue turns v into a types.Value.
|
// MakeValue turns v into a types.Value.
|
||||||
func MakeValue(t testing.TB, v interface{}) types.Value {
|
func MakeValue(t testing.TB, v any) types.Value {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
vv, err := object.NewValue(v)
|
vv, err := object.NewValue(v)
|
||||||
@@ -26,7 +26,7 @@ func MakeValue(t testing.TB, v interface{}) types.Value {
|
|||||||
return vv
|
return vv
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeArrayValue(t testing.TB, vs ...interface{}) types.Value {
|
func MakeArrayValue(t testing.TB, vs ...any) types.Value {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
vvs := []types.Value{}
|
vvs := []types.Value{}
|
||||||
@@ -168,6 +168,19 @@ func RequireArrayEqual(t testing.TB, want, got types.Array) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RequireValueEqual(t testing.TB, want, got types.Value, msg string, args ...any) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tWant, err := types.MarshalTextIndent(want, "\n", " ")
|
||||||
|
require.NoError(t, err)
|
||||||
|
tGot, err := types.MarshalTextIndent(got, "\n", " ")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(string(tWant), string(tGot), cmp.Comparer(strings.EqualFold)); diff != "" {
|
||||||
|
require.Failf(t, "mismatched values, (-want, +got)", "%s\n%s", diff, fmt.Sprintf(msg, args...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func CloneObject(t testing.TB, d types.Object) *object.FieldBuffer {
|
func CloneObject(t testing.TB, d types.Object) *object.FieldBuffer {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user