diff --git a/internal/expr/arithmeric.go b/internal/expr/arithmeric.go index 362bf2ab..e15ef774 100644 --- a/internal/expr/arithmeric.go +++ b/internal/expr/arithmeric.go @@ -17,6 +17,12 @@ type arithmeticOperator struct { *simpleOperator } +func (op *arithmeticOperator) Clone() Expr { + return &arithmeticOperator{ + simpleOperator: op.simpleOperator.Clone(), + } +} + func (op *arithmeticOperator) Eval(env *environment.Environment) (types.Value, error) { return op.simpleOperator.eval(env, func(va, vb types.Value) (types.Value, error) { a, ok := va.(types.Numeric) diff --git a/internal/expr/comparison.go b/internal/expr/comparison.go index ccbcdfbd..01c64780 100644 --- a/internal/expr/comparison.go +++ b/internal/expr/comparison.go @@ -59,6 +59,10 @@ func (op *cmpOp) compare(l, r types.Value) (bool, error) { } } +func (op *cmpOp) Clone() Expr { + return &cmpOp{op.simpleOperator.Clone()} +} + // Eq creates an expression that returns true if a equals b. func Eq(a, b Expr) Expr { return newCmpOp(a, b, scanner.EQ) @@ -102,6 +106,13 @@ func Between(a Expr) func(x, b Expr) Expr { } } +func (op *BetweenOperator) Clone() Expr { + return &BetweenOperator{ + op.simpleOperator.Clone(), + Clone(op.X), + } +} + func (op *BetweenOperator) Eval(env *environment.Environment) (types.Value, error) { x, err := op.X.Eval(env) if err != nil { @@ -152,15 +163,23 @@ func In(a Expr, b Expr) Expr { return &InOperator{a, b, scanner.IN} } -func (op InOperator) Precedence() int { +func (op *InOperator) Clone() Expr { + return &InOperator{ + Clone(op.a), + Clone(op.b), + op.op, + } +} + +func (op *InOperator) Precedence() int { return op.op.Precedence() } -func (op InOperator) LeftHand() Expr { +func (op *InOperator) LeftHand() Expr { return op.a } -func (op InOperator) RightHand() Expr { +func (op *InOperator) RightHand() Expr { return op.b } @@ -242,12 +261,18 @@ func (op *InOperator) validateRightExpression(b Expr) (LiteralExprList, error) { } type NotInOperator struct { - InOperator + *InOperator } // NotIn creates an expression that evaluates to the result of a NOT IN b. func NotIn(a Expr, b Expr) Expr { - return &NotInOperator{InOperator{a, b, scanner.NIN}} + return &NotInOperator{&InOperator{a, b, scanner.NIN}} +} + +func (op *NotInOperator) Clone() Expr { + return &NotInOperator{ + op.InOperator.Clone().(*InOperator), + } } func (op *NotInOperator) Eval(env *environment.Environment) (types.Value, error) { @@ -267,6 +292,12 @@ func Is(a, b Expr) Expr { return &IsOperator{&simpleOperator{a, b, scanner.IN}} } +func (op *IsOperator) Clone() Expr { + return &IsOperator{ + op.simpleOperator.Clone(), + } +} + func (op *IsOperator) Eval(env *environment.Environment) (types.Value, error) { return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) { ok, err := a.EQ(b) @@ -290,6 +321,12 @@ func IsNot(a, b Expr) Expr { return &IsNotOperator{&simpleOperator{a, b, scanner.ISN}} } +func (op *IsNotOperator) Clone() Expr { + return &IsNotOperator{ + op.simpleOperator.Clone(), + } +} + func (op *IsNotOperator) Eval(env *environment.Environment) (types.Value, error) { return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) { eq, err := a.EQ(b) diff --git a/internal/expr/expr.go b/internal/expr/expr.go index d34da353..c3a7c1d4 100644 --- a/internal/expr/expr.go +++ b/internal/expr/expr.go @@ -112,6 +112,7 @@ type Function interface { // Params returns the list of parameters this function has received. Params() []Expr + Clone() Expr } // An Aggregator is an expression that aggregates objects into one result. @@ -238,3 +239,39 @@ func (n NextValueFor) String() string { // return types.TypeNull, fmt.Errorf("unexpected expression type: %T", e) // } + +func Clone(e Expr) Expr { + if e == nil { + return nil + } + + switch e := e.(type) { + case cloner: + return e.Clone() + case Parentheses: + return Parentheses{E: Clone(e.E)} + case *NamedExpr: + return &NamedExpr{ + Expr: Clone(e.Expr), + ExprName: e.ExprName, + } + case *Cast: + return &Cast{ + Expr: Clone(e.Expr), + CastAs: e.CastAs, + } + case LiteralValue, + Column, + NamedParam, + PositionalParam, + NextValueFor, + Wildcard: + return e + } + + panic(fmt.Sprintf("clone: unexpected expression type: %T", e)) +} + +type cloner interface { + Clone() Expr +} diff --git a/internal/expr/functions/builtins.go b/internal/expr/functions/builtins.go index 5e7c973d..bb6c1019 100644 --- a/internal/expr/functions/builtins.go +++ b/internal/expr/functions/builtins.go @@ -127,6 +127,12 @@ type TypeOf struct { Expr expr.Expr } +func (t *TypeOf) Clone() expr.Expr { + return &TypeOf{ + Expr: expr.Clone(t.Expr), + } +} + func (t *TypeOf) Eval(env *environment.Environment) (types.Value, error) { v, err := t.Expr.Eval(env) if err != nil { @@ -175,6 +181,14 @@ func NewCount(e expr.Expr) *Count { } } +func (t *Count) Clone() expr.Expr { + return &Count{ + Expr: expr.Clone(t.Expr), + wildcard: t.wildcard, + Count: t.Count, + } +} + func (c *Count) Eval(env *environment.Environment) (types.Value, error) { d, ok := env.GetRow() if !ok { @@ -250,6 +264,12 @@ type Min struct { Expr expr.Expr } +func (t *Min) Clone() expr.Expr { + return &Min{ + Expr: expr.Clone(t.Expr), + } +} + // Eval extracts the min value from the given object and returns it. func (m *Min) Eval(env *environment.Environment) (types.Value, error) { r, ok := env.GetRow() @@ -354,6 +374,12 @@ type Max struct { Expr expr.Expr } +func (t *Max) Clone() expr.Expr { + return &Max{ + Expr: expr.Clone(t.Expr), + } +} + // Eval extracts the max value from the given object and returns it. func (m *Max) Eval(env *environment.Environment) (types.Value, error) { r, ok := env.GetRow() @@ -453,6 +479,12 @@ type Sum struct { Expr expr.Expr } +func (t *Sum) Clone() expr.Expr { + return &Sum{ + Expr: expr.Clone(t.Expr), + } +} + // Eval extracts the sum value from the given object and returns it. func (s *Sum) Eval(env *environment.Environment) (types.Value, error) { r, ok := env.GetRow() @@ -564,6 +596,12 @@ type Avg struct { Expr expr.Expr } +func (t *Avg) Clone() expr.Expr { + return &Avg{ + Expr: expr.Clone(t.Expr), + } +} + // Eval extracts the average value from the given object and returns it. func (s *Avg) Eval(env *environment.Environment) (types.Value, error) { r, ok := env.GetRow() @@ -651,6 +689,12 @@ type Len struct { Expr expr.Expr } +func (t *Len) Clone() expr.Expr { + return &Len{ + Expr: expr.Clone(t.Expr), + } +} + // Eval extracts the average value from the given object and returns it. func (s *Len) Eval(env *environment.Environment) (types.Value, error) { val, err := s.Expr.Eval(env) @@ -694,6 +738,16 @@ type Coalesce struct { Exprs []expr.Expr } +func (c *Coalesce) Clone() expr.Expr { + var clone Coalesce + clone.Exprs = make([]expr.Expr, 0, len(c.Exprs)) + for _, e := range c.Exprs { + clone.Exprs = append(clone.Exprs, expr.Clone(e)) + } + + return &clone +} + func (c *Coalesce) Eval(e *environment.Environment) (types.Value, error) { for _, exp := range c.Exprs { v, err := exp.Eval(e) @@ -717,6 +771,10 @@ func (c *Coalesce) Params() []expr.Expr { type Now struct{} +func (n *Now) Clone() expr.Expr { + return &Now{} +} + func (n *Now) Eval(env *environment.Environment) (types.Value, error) { tx := env.GetTx() if tx == nil { diff --git a/internal/expr/functions/scalar_definition.go b/internal/expr/functions/scalar_definition.go index d8afa268..909266e5 100644 --- a/internal/expr/functions/scalar_definition.go +++ b/internal/expr/functions/scalar_definition.go @@ -56,6 +56,18 @@ type ScalarFunction struct { params []expr.Expr } +func (sf *ScalarFunction) Clone() expr.Expr { + exprs := make([]expr.Expr, 0, len(sf.params)) + for _, e := range sf.params { + exprs = append(exprs, expr.Clone(e)) + } + + return &ScalarFunction{ + def: sf.def, + params: exprs, + } +} + // Eval returns a row.Value based on the given environment and the underlying function // definition. func (sf *ScalarFunction) Eval(env *environment.Environment) (types.Value, error) { diff --git a/internal/expr/functions/strings.go b/internal/expr/functions/strings.go index 45dbdb62..75d13499 100644 --- a/internal/expr/functions/strings.go +++ b/internal/expr/functions/strings.go @@ -15,6 +15,12 @@ type Lower struct { Expr expr.Expr } +func (s *Lower) Clone() expr.Expr { + return &Lower{ + Expr: expr.Clone(s.Expr), + } +} + func (s *Lower) Eval(env *environment.Environment) (types.Value, error) { val, err := s.Expr.Eval(env) if err != nil { @@ -55,6 +61,12 @@ type Upper struct { Expr expr.Expr } +func (s *Upper) Clone() expr.Expr { + return &Upper{ + Expr: expr.Clone(s.Expr), + } +} + func (s *Upper) Eval(env *environment.Environment) (types.Value, error) { val, err := s.Expr.Eval(env) if err != nil { @@ -101,6 +113,18 @@ type Trim struct { type TrimFunc func(string, string) string +func (s *Trim) Clone() expr.Expr { + exprs := make([]expr.Expr, len(s.Expr)) + for i := range s.Expr { + exprs[i] = expr.Clone(s.Expr[i]) + } + return &Trim{ + Expr: exprs, + TrimFunc: s.TrimFunc, + Name: s.Name, + } +} + func (s *Trim) Eval(env *environment.Environment) (types.Value, error) { if len(s.Expr) > 2 { return nil, fmt.Errorf("misuse of string function %v()", s.Name) diff --git a/internal/expr/like.go b/internal/expr/like.go index d9e70c3b..9880034f 100644 --- a/internal/expr/like.go +++ b/internal/expr/like.go @@ -22,6 +22,12 @@ func Like(a, b Expr) Expr { return &LikeOperator{&simpleOperator{a, b, scanner.LIKE}} } +func (op *LikeOperator) Clone() Expr { + return &LikeOperator{ + simpleOperator: op.simpleOperator.Clone(), + } +} + func (op *LikeOperator) Eval(env *environment.Environment) (types.Value, error) { return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) { if a.Type() != types.TypeText || b.Type() != types.TypeText { @@ -36,6 +42,10 @@ func (op *LikeOperator) Eval(env *environment.Environment) (types.Value, error) }) } +func (op *LikeOperator) String() string { + return fmt.Sprintf("%v LIKE %v", op.a, op.b) +} + type NotLikeOperator struct { *LikeOperator } @@ -45,6 +55,12 @@ func NotLike(a, b Expr) Expr { return &NotLikeOperator{&LikeOperator{&simpleOperator{a, b, scanner.NLIKE}}} } +func (op *NotLikeOperator) Clone() Expr { + return &NotLikeOperator{ + LikeOperator: op.LikeOperator.Clone().(*LikeOperator), + } +} + func (op *NotLikeOperator) Eval(env *environment.Environment) (types.Value, error) { return invertBoolResult(op.LikeOperator.Eval)(env) } diff --git a/internal/expr/literal.go b/internal/expr/literal.go index 044bfab5..7748e290 100644 --- a/internal/expr/literal.go +++ b/internal/expr/literal.go @@ -37,6 +37,14 @@ func (v LiteralValue) Eval(*environment.Environment) (types.Value, error) { // LiteralExprList is a list of expressions. type LiteralExprList []Expr +func (l LiteralExprList) Clone() Expr { + exprs := make(LiteralExprList, len(l)) + for i, e := range l { + exprs[i] = Clone(e) + } + return exprs +} + // IsEqual compares this expression with the other expression and returns // true if they are equal. func (l LiteralExprList) IsEqual(o LiteralExprList) bool { diff --git a/internal/expr/logical.go b/internal/expr/logical.go index 38bf4009..b5b518d1 100644 --- a/internal/expr/logical.go +++ b/internal/expr/logical.go @@ -18,6 +18,12 @@ func And(a, b Expr) Expr { return &AndOp{&simpleOperator{a, b, scanner.AND}} } +func (op *AndOp) Clone() Expr { + return &AndOp{ + simpleOperator: op.simpleOperator.Clone(), + } +} + // Eval implements the Expr interface. It evaluates a and b and returns true if both evaluate // to true. func (op *AndOp) Eval(env *environment.Environment) (types.Value, error) { @@ -52,6 +58,12 @@ func Or(a, b Expr) Expr { return &OrOp{&simpleOperator{a, b, scanner.OR}} } +func (op *OrOp) Clone() Expr { + return &OrOp{ + simpleOperator: op.simpleOperator.Clone(), + } +} + // Eval implements the Expr interface. It evaluates a and b and returns true if a or b evalutate // to true. func (op *OrOp) Eval(env *environment.Environment) (types.Value, error) { @@ -92,6 +104,12 @@ func Not(e Expr) Expr { return &NotOp{&simpleOperator{a: e}} } +func (op *NotOp) Clone() Expr { + return &NotOp{ + simpleOperator: op.simpleOperator.Clone(), + } +} + // Eval implements the Expr interface. It evaluates e and returns true if b is falsy func (op *NotOp) Eval(env *environment.Environment) (types.Value, error) { s, err := op.a.Eval(env) diff --git a/internal/expr/operator.go b/internal/expr/operator.go index a8aaddda..dce0f5a3 100644 --- a/internal/expr/operator.go +++ b/internal/expr/operator.go @@ -38,6 +38,14 @@ func (op *simpleOperator) Token() scanner.Token { return op.Tok } +func (op *simpleOperator) Clone() *simpleOperator { + return &simpleOperator{ + a: Clone(op.a), + b: Clone(op.b), + Tok: op.Tok, + } +} + func (op *simpleOperator) eval(env *environment.Environment, fn func(a, b types.Value) (types.Value, error)) (types.Value, error) { if op.a == nil || op.b == nil { return NullLiteral, errors.New("missing operand") @@ -117,7 +125,7 @@ type Cast struct { } // Eval returns the primary key of the current row. -func (c Cast) Eval(env *environment.Environment) (types.Value, error) { +func (c *Cast) Eval(env *environment.Environment) (types.Value, error) { v, err := c.Expr.Eval(env) if err != nil { return v, err @@ -128,12 +136,12 @@ func (c Cast) Eval(env *environment.Environment) (types.Value, error) { // IsEqual compares this expression with the other expression and returns // true if they are equal. -func (c Cast) IsEqual(other Expr) bool { +func (c *Cast) IsEqual(other Expr) bool { if other == nil { return false } - o, ok := other.(Cast) + o, ok := other.(*Cast) if !ok { return false } @@ -149,8 +157,8 @@ func (c Cast) IsEqual(other Expr) bool { return o.Expr != nil } -func (c Cast) Params() []Expr { return []Expr{c.Expr} } +func (c *Cast) Params() []Expr { return []Expr{c.Expr} } -func (c Cast) String() string { +func (c *Cast) String() string { return fmt.Sprintf("CAST(%v AS %v)", c.Expr, c.CastAs) } diff --git a/internal/query/statement/stream.go b/internal/query/statement/stream.go index 53b4c7ba..bfd363e1 100644 --- a/internal/query/statement/stream.go +++ b/internal/query/statement/stream.go @@ -31,7 +31,7 @@ type PreparedStreamStmt struct { // Run returns a result containing the stream. The stream will be executed by calling the Iterate method of // the result. func (s *PreparedStreamStmt) Run(ctx *Context) (Result, error) { - st, err := planner.Optimize(s.Stream, ctx.Tx.Catalog, ctx.Params) + st, err := planner.Optimize(s.Stream.Clone(), ctx.Tx.Catalog, ctx.Params) if err != nil { return Result{}, err } diff --git a/internal/sql/parser/expr.go b/internal/sql/parser/expr.go index 3f0e70cb..24d6bdff 100644 --- a/internal/sql/parser/expr.go +++ b/internal/sql/parser/expr.go @@ -578,7 +578,7 @@ func (p *Parser) parseCastExpression() (expr.Expr, error) { return nil, err } - return expr.Cast{Expr: e, CastAs: tp}, nil + return &expr.Cast{Expr: e, CastAs: tp}, nil } // tokenIsAllowed is a helper function that determines if a token is allowed. diff --git a/internal/sql/parser/expr_test.go b/internal/sql/parser/expr_test.go index 5547efe4..5aa7cfef 100644 --- a/internal/sql/parser/expr_test.go +++ b/internal/sql/parser/expr_test.go @@ -112,7 +112,7 @@ func TestParserExpr(t *testing.T) { {"with NULL", "age > NULL", expr.Gt(expr.Column("age"), testutil.NullValue()), false}, // unary operators - {"CAST", "CAST(a AS TEXT)", expr.Cast{Expr: expr.Column("a"), CastAs: types.TypeText}, false}, + {"CAST", "CAST(a AS TEXT)", &expr.Cast{Expr: expr.Column("a"), CastAs: types.TypeText}, false}, {"NOT", "NOT 10", expr.Not(testutil.IntegerValue(10)), false}, {"NOT", "NOT NOT", nil, true}, {"NOT", "NOT NOT 10", expr.Not(expr.Not(testutil.IntegerValue(10))), false}, diff --git a/internal/stream/concat.go b/internal/stream/concat.go index 8b4d5017..48ea2ed7 100644 --- a/internal/stream/concat.go +++ b/internal/stream/concat.go @@ -17,6 +17,18 @@ func Concat(s ...*Stream) *ConcatOperator { return &ConcatOperator{Streams: s} } +func (it *ConcatOperator) Clone() Operator { + streams := make([]*Stream, len(it.Streams)) + for i, s := range it.Streams { + streams[i] = s.Clone() + } + + return &ConcatOperator{ + BaseOperator: it.BaseOperator.Clone(), + Streams: streams, + } +} + func (it *ConcatOperator) Iterate(in *environment.Environment, fn func(*environment.Environment) error) error { for _, s := range it.Streams { if err := s.Iterate(in, fn); err != nil { diff --git a/internal/stream/index/delete.go b/internal/stream/index/delete.go index 81e02bd7..3386496e 100644 --- a/internal/stream/index/delete.go +++ b/internal/stream/index/delete.go @@ -22,6 +22,13 @@ func Delete(indexName string) *DeleteOperator { } } +func (op *DeleteOperator) Clone() stream.Operator { + return &DeleteOperator{ + BaseOperator: op.BaseOperator.Clone(), + indexName: op.indexName, + } +} + func (op *DeleteOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { tx := in.GetTx() diff --git a/internal/stream/index/insert.go b/internal/stream/index/insert.go index 0804e4b5..ee081afc 100644 --- a/internal/stream/index/insert.go +++ b/internal/stream/index/insert.go @@ -22,6 +22,13 @@ func Insert(indexName string) *InsertOperator { } } +func (op *InsertOperator) Clone() stream.Operator { + return &InsertOperator{ + BaseOperator: op.BaseOperator.Clone(), + indexName: op.indexName, + } +} + func (op *InsertOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { tx := in.GetTx() diff --git a/internal/stream/index/scan.go b/internal/stream/index/scan.go index 87a0ff5e..cdf7916b 100644 --- a/internal/stream/index/scan.go +++ b/internal/stream/index/scan.go @@ -35,6 +35,15 @@ func ScanReverse(name string, ranges ...stream.Range) *ScanOperator { return &ScanOperator{IndexName: name, Ranges: ranges, Reverse: true} } +func (op *ScanOperator) Clone() stream.Operator { + return &ScanOperator{ + BaseOperator: op.BaseOperator.Clone(), + IndexName: op.IndexName, + Ranges: op.Ranges.Clone(), + Reverse: op.Reverse, + } +} + // Iterate over the objects of the table. Each object is stored in the environment // that is passed to the fn function, using SetCurrentValue. func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { diff --git a/internal/stream/index/validate.go b/internal/stream/index/validate.go index 5eafc8db..c54752bc 100644 --- a/internal/stream/index/validate.go +++ b/internal/stream/index/validate.go @@ -23,6 +23,13 @@ func Validate(indexName string) *ValidateOperator { } } +func (op *ValidateOperator) Clone() stream.Operator { + return &ValidateOperator{ + BaseOperator: op.BaseOperator.Clone(), + indexName: op.indexName, + } +} + func (op *ValidateOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { tx := in.GetTx() diff --git a/internal/stream/on_conflict.go b/internal/stream/on_conflict.go index 3aec4322..656742db 100644 --- a/internal/stream/on_conflict.go +++ b/internal/stream/on_conflict.go @@ -21,6 +21,13 @@ func OnConflict(onConflict *Stream) *OnConflictOperator { } } +func (it *OnConflictOperator) Clone() Operator { + return &OnConflictOperator{ + BaseOperator: it.BaseOperator.Clone(), + OnConflict: it.OnConflict.Clone(), + } +} + func (op *OnConflictOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { var newEnv environment.Environment diff --git a/internal/stream/operator.go b/internal/stream/operator.go index 4a45bd7a..5574cec1 100644 --- a/internal/stream/operator.go +++ b/internal/stream/operator.go @@ -25,6 +25,7 @@ type Operator interface { GetNext() Operator GetPrev() Operator String() string + Clone() Operator } // An OperatorFunc is the function that will receive each value of the stream. @@ -59,3 +60,7 @@ func (op *BaseOperator) GetPrev() Operator { func (op *BaseOperator) GetNext() Operator { return op.Next } + +func (op BaseOperator) Clone() BaseOperator { + return op +} diff --git a/internal/stream/path/rename.go b/internal/stream/path/rename.go index 7b1b2a23..d1ed8257 100644 --- a/internal/stream/path/rename.go +++ b/internal/stream/path/rename.go @@ -27,6 +27,14 @@ func PathsRename(columnNames ...string) *RenameOperator { } } +func (op *RenameOperator) Clone() stream.Operator { + return &RenameOperator{ + BaseOperator: op.BaseOperator.Clone(), + // No need to clone the column names, they are immutable. + ColumnNames: op.ColumnNames, + } +} + // Iterate implements the Operator interface. func (op *RenameOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { var cb row.ColumnBuffer diff --git a/internal/stream/path/set.go b/internal/stream/path/set.go index 94edf661..5f83bb7c 100644 --- a/internal/stream/path/set.go +++ b/internal/stream/path/set.go @@ -27,6 +27,14 @@ func Set(column string, e expr.Expr) *SetOperator { } } +func (op *SetOperator) Clone() stream.Operator { + return &SetOperator{ + BaseOperator: op.BaseOperator.Clone(), + Column: op.Column, + Expr: expr.Clone(op.Expr), + } +} + // Iterate implements the Operator interface. func (op *SetOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { var cb row.ColumnBuffer diff --git a/internal/stream/range.go b/internal/stream/range.go index 1d68e497..d04bcb97 100644 --- a/internal/stream/range.go +++ b/internal/stream/range.go @@ -23,6 +23,17 @@ type Range struct { Exact bool } +func (r *Range) Clone() Range { + return Range{ + Min: expr.Clone(r.Min).(expr.LiteralExprList), + Max: expr.Clone(r.Max).(expr.LiteralExprList), + // No need to clone the columns, they are immutable. + Columns: r.Columns, + Exclusive: r.Exclusive, + Exact: r.Exact, + } +} + func (r *Range) Eval(env *environment.Environment) (*database.Range, error) { rng := database.Range{ Exclusive: r.Exclusive, @@ -119,6 +130,19 @@ func (r *Range) IsEqual(other *Range) bool { type Ranges []Range +func (r Ranges) Clone() Ranges { + if r == nil { + return nil + } + + clone := make(Ranges, len(r)) + for i := range r { + clone[i] = r[i].Clone() + } + + return clone +} + // Encode each range using the given value encoder. func (r Ranges) Eval(env *environment.Environment) ([]*database.Range, error) { ranges := make([]*database.Range, 0, len(r)) diff --git a/internal/stream/rows/emit.go b/internal/stream/rows/emit.go index cacb7c3d..196b9f5b 100644 --- a/internal/stream/rows/emit.go +++ b/internal/stream/rows/emit.go @@ -40,6 +40,13 @@ func (op *EmitOperator) Iterate(in *environment.Environment, fn func(out *enviro return nil } +func (op *EmitOperator) Clone() stream.Operator { + return &EmitOperator{ + BaseOperator: op.BaseOperator.Clone(), + Rows: op.Rows, + } +} + func (op *EmitOperator) String() string { var sb strings.Builder diff --git a/internal/stream/rows/filter.go b/internal/stream/rows/filter.go index 0d000391..548c5e98 100644 --- a/internal/stream/rows/filter.go +++ b/internal/stream/rows/filter.go @@ -37,6 +37,13 @@ func (op *FilterOperator) Iterate(in *environment.Environment, f func(out *envir }) } +func (op *FilterOperator) Clone() stream.Operator { + return &FilterOperator{ + BaseOperator: op.BaseOperator.Clone(), + Expr: expr.Clone(op.Expr), + } +} + func (op *FilterOperator) String() string { return fmt.Sprintf("rows.Filter(%s)", op.Expr) } diff --git a/internal/stream/rows/group_aggregate.go b/internal/stream/rows/group_aggregate.go index 030d562e..46dd777a 100644 --- a/internal/stream/rows/group_aggregate.go +++ b/internal/stream/rows/group_aggregate.go @@ -25,6 +25,18 @@ func GroupAggregate(groupBy expr.Expr, builders ...expr.AggregatorBuilder) *Grou return &GroupAggregateOperator{E: groupBy, Builders: builders} } +func (op *GroupAggregateOperator) Clone() stream.Operator { + builders := make([]expr.AggregatorBuilder, len(op.Builders)) + for i, b := range op.Builders { + builders[i] = expr.Clone(b).(expr.AggregatorBuilder) + } + return &GroupAggregateOperator{ + BaseOperator: op.BaseOperator.Clone(), + Builders: builders, + E: expr.Clone(op.E), + } +} + func (op *GroupAggregateOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { var lastGroup types.Value var ga *groupAggregator diff --git a/internal/stream/rows/project.go b/internal/stream/rows/project.go index b1744db9..5fa0ce88 100644 --- a/internal/stream/rows/project.go +++ b/internal/stream/rows/project.go @@ -24,6 +24,17 @@ func Project(exprs ...expr.Expr) *ProjectOperator { return &ProjectOperator{Exprs: exprs} } +func (op *ProjectOperator) Clone() stream.Operator { + exprs := make([]expr.Expr, len(op.Exprs)) + for i, e := range op.Exprs { + exprs[i] = expr.Clone(e) + } + return &ProjectOperator{ + BaseOperator: op.BaseOperator.Clone(), + Exprs: exprs, + } +} + // Iterate implements the Operator interface. func (op *ProjectOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { var mask RowMask diff --git a/internal/stream/rows/skip.go b/internal/stream/rows/skip.go index 95202f7c..7e458359 100644 --- a/internal/stream/rows/skip.go +++ b/internal/stream/rows/skip.go @@ -20,6 +20,13 @@ func Skip(e expr.Expr) *SkipOperator { return &SkipOperator{E: e} } +func (op *SkipOperator) Clone() stream.Operator { + return &SkipOperator{ + BaseOperator: op.BaseOperator.Clone(), + E: expr.Clone(op.E), + } +} + // Iterate implements the Operator interface. func (op *SkipOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { v, err := op.E.Eval(in) diff --git a/internal/stream/rows/take.go b/internal/stream/rows/take.go index f86d8623..d96a2b82 100644 --- a/internal/stream/rows/take.go +++ b/internal/stream/rows/take.go @@ -21,6 +21,13 @@ func Take(e expr.Expr) *TakeOperator { return &TakeOperator{E: e} } +func (op *TakeOperator) Clone() stream.Operator { + return &TakeOperator{ + BaseOperator: op.BaseOperator.Clone(), + E: expr.Clone(op.E), + } +} + // Iterate implements the Operator interface. func (op *TakeOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { v, err := op.E.Eval(in) diff --git a/internal/stream/rows/temp_tree_sort.go b/internal/stream/rows/temp_tree_sort.go index 32ea28b0..33586a63 100644 --- a/internal/stream/rows/temp_tree_sort.go +++ b/internal/stream/rows/temp_tree_sort.go @@ -31,6 +31,14 @@ func TempTreeSortReverse(e expr.Expr) *TempTreeSortOperator { return &TempTreeSortOperator{Expr: e, Desc: true} } +func (op *TempTreeSortOperator) Clone() stream.Operator { + return &TempTreeSortOperator{ + BaseOperator: op.BaseOperator.Clone(), + Expr: expr.Clone(op.Expr), + Desc: op.Desc, + } +} + func (op *TempTreeSortOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { db := in.GetDB() diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 11a15a10..b4dd38c9 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -82,6 +82,25 @@ func (s *Stream) String() string { return sb.String() } +func (s *Stream) Clone() *Stream { + if s == nil { + return nil + } + + if s.Op == nil { + return New(nil) + } + + op := s.First() + var ops []Operator + for op != nil { + ops = append(ops, op.Clone()) + op = op.GetNext() + } + + return New(Pipe(ops...)) +} + func InsertBefore(op, newOp Operator) Operator { if op != nil { prev := op.GetPrev() @@ -123,6 +142,11 @@ type DiscardOperator struct { func Discard() *DiscardOperator { return &DiscardOperator{} } +func (it *DiscardOperator) Clone() Operator { + return &DiscardOperator{ + BaseOperator: it.BaseOperator.Clone(), + } +} // Iterate iterates over all the streams and returns their union. func (op *DiscardOperator) Iterate(in *environment.Environment, _ func(out *environment.Environment) error) (err error) { diff --git a/internal/stream/table/delete.go b/internal/stream/table/delete.go index 308174bc..030fca86 100644 --- a/internal/stream/table/delete.go +++ b/internal/stream/table/delete.go @@ -20,6 +20,13 @@ func Delete(tableName string) *DeleteOperator { return &DeleteOperator{Name: tableName} } +func (op *DeleteOperator) Clone() stream.Operator { + return &DeleteOperator{ + BaseOperator: op.BaseOperator.Clone(), + Name: op.Name, + } +} + // Iterate implements the Operator interface. func (op *DeleteOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { var table *database.Table diff --git a/internal/stream/table/insert.go b/internal/stream/table/insert.go index c001c414..b57398bb 100644 --- a/internal/stream/table/insert.go +++ b/internal/stream/table/insert.go @@ -20,6 +20,13 @@ func Insert(tableName string) *InsertOperator { return &InsertOperator{Name: tableName} } +func (op *InsertOperator) Clone() stream.Operator { + return &InsertOperator{ + BaseOperator: op.BaseOperator.Clone(), + Name: op.Name, + } +} + // Iterate implements the Operator interface. func (op *InsertOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { var newEnv environment.Environment diff --git a/internal/stream/table/replace.go b/internal/stream/table/replace.go index 63484c95..12dfe417 100644 --- a/internal/stream/table/replace.go +++ b/internal/stream/table/replace.go @@ -20,6 +20,13 @@ func Replace(tableName string) *ReplaceOperator { return &ReplaceOperator{Name: tableName} } +func (op *ReplaceOperator) Clone() stream.Operator { + return &ReplaceOperator{ + BaseOperator: op.BaseOperator.Clone(), + Name: op.Name, + } +} + // Iterate implements the Operator interface. func (op *ReplaceOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { var table *database.Table diff --git a/internal/stream/table/scan.go b/internal/stream/table/scan.go index 6ea42368..86d4fb71 100644 --- a/internal/stream/table/scan.go +++ b/internal/stream/table/scan.go @@ -33,6 +33,16 @@ func ScanReverse(tableName string, ranges ...stream.Range) *ScanOperator { return &ScanOperator{TableName: tableName, Ranges: ranges, Reverse: true} } +func (op *ScanOperator) Clone() stream.Operator { + return &ScanOperator{ + BaseOperator: op.BaseOperator.Clone(), + TableName: op.TableName, + Ranges: op.Ranges.Clone(), + Reverse: op.Reverse, + Table: op.Table, + } +} + // Iterate over the objects of the table. Each object is stored in the environment // that is passed to the fn function, using SetCurrentValue. func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { diff --git a/internal/stream/table/validate.go b/internal/stream/table/validate.go index 57eaf17f..09159160 100644 --- a/internal/stream/table/validate.go +++ b/internal/stream/table/validate.go @@ -22,6 +22,13 @@ func Validate(tableName string) *ValidateOperator { } } +func (op *ValidateOperator) Clone() stream.Operator { + return &ValidateOperator{ + BaseOperator: op.BaseOperator.Clone(), + tableName: op.tableName, + } +} + func (op *ValidateOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { tx := in.GetTx() diff --git a/internal/stream/union.go b/internal/stream/union.go index 24d8e5d2..65d71ef9 100644 --- a/internal/stream/union.go +++ b/internal/stream/union.go @@ -22,6 +22,18 @@ func Union(s ...*Stream) *UnionOperator { return &UnionOperator{Streams: s} } +func (it *UnionOperator) Clone() Operator { + streams := make([]*Stream, len(it.Streams)) + for i, s := range it.Streams { + streams[i] = s.Clone() + } + + return &UnionOperator{ + BaseOperator: it.BaseOperator.Clone(), + Streams: streams, + } +} + // Iterate iterates over all the streams and returns their union. func (it *UnionOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) (err error) { var temp *tree.Tree