stream: add Columns method

This commit is contained in:
Asdine El Hrychy
2024-02-18 13:46:09 +04:00
parent 6bc4992d70
commit 6f8c2d2b47
22 changed files with 188 additions and 129 deletions

30
db.go
View File

@@ -20,7 +20,6 @@ import (
"github.com/chaisql/chai/internal/row"
"github.com/chaisql/chai/internal/sql/parser"
"github.com/chaisql/chai/internal/stream"
"github.com/chaisql/chai/internal/stream/rows"
"github.com/chaisql/chai/internal/types"
"github.com/cockroachdb/errors"
)
@@ -411,35 +410,22 @@ func (r *Result) GetFirst() (*Row, error) {
return rr, nil
}
func (r *Result) Columns() []string {
func (r *Result) Columns() ([]string, error) {
if r.result.Iterator == nil {
return nil
return nil, nil
}
stmt, ok := r.result.Iterator.(*statement.StreamStmtIterator)
if !ok || stmt.Stream.Op == nil {
return nil
return nil, nil
}
// Search for the ProjectOperator. If found, extract the projected expression list
for op := stmt.Stream.First(); op != nil; op = op.GetNext() {
if po, ok := op.(*rows.ProjectOperator); ok {
// if there are no projected expression, it's a wildcard
if len(po.Exprs) == 0 {
break
}
var env environment.Environment
env.DB = stmt.Context.DB
env.Tx = stmt.Context.Tx
env.SetParams(stmt.Context.Params)
columns := make([]string, len(po.Exprs))
for i := range po.Exprs {
columns[i] = po.Exprs[i].String()
}
return columns
}
}
// the stream will output rows in a single field
return []string{"*"}
return stmt.Stream.Columns(&env)
}
// Close the result stream.

View File

@@ -1,7 +1,6 @@
package chai_test
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -10,7 +9,6 @@ import (
"github.com/chaisql/chai"
"github.com/chaisql/chai/internal/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
func ExampleTx() {
@@ -212,42 +210,6 @@ func TestQueryRow(t *testing.T) {
})
}
func TestPrepareThreadSafe(t *testing.T) {
db, err := chai.Open(":memory:")
require.NoError(t, err)
defer db.Close()
conn, err := db.Connect()
require.NoError(t, err)
defer conn.Close()
err = conn.Exec("CREATE TABLE test(a int unique, b text); INSERT INTO test(a, b) VALUES (1, 'a'), (2, 'a')")
require.NoError(t, err)
stmt, err := conn.Prepare("SELECT COUNT(a) FROM test WHERE a < ? GROUP BY b ORDER BY a DESC LIMIT 5")
require.NoError(t, err)
g, _ := errgroup.WithContext(context.Background())
for i := 1; i <= 3; i++ {
arg := i
g.Go(func() error {
res, err := stmt.Query(arg)
if err != nil {
return err
}
defer res.Close()
return res.Iterate(func(d *chai.Row) error {
return nil
})
})
}
err = g.Wait()
require.NoError(t, err)
}
func TestIterateDeepCopy(t *testing.T) {
db, err := chai.Open(":memory:")
require.NoError(t, err)

View File

@@ -10,7 +10,6 @@ import (
"github.com/chaisql/chai"
"github.com/chaisql/chai/internal/environment"
"github.com/chaisql/chai/internal/row"
"github.com/chaisql/chai/internal/types"
"github.com/cockroachdb/errors"
)
@@ -192,9 +191,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
return nil, err
}
rs := newRows(res)
rs.columns = res.Columns()
return rs, nil
return newRows(res)
}
func namedValueToParams(args []driver.NamedValue) []any {
@@ -229,7 +226,7 @@ type Row struct {
err error
}
func newRows(res *chai.Result) *Rows {
func newRows(res *chai.Result) (*Rows, error) {
ctx, cancel := context.WithCancel(context.Background())
rs := Rows{
@@ -239,9 +236,16 @@ func newRows(res *chai.Result) *Rows {
}
rs.wg.Add(1)
cols, err := rs.res.Columns()
if err != nil {
return nil, err
}
rs.columns = cols
go rs.iterate(ctx)
return &rs
return &rs, nil
}
func (rs *Rows) iterate(ctx context.Context) {
@@ -284,7 +288,7 @@ func (rs *Rows) iterate(ctx context.Context) {
// Columns returns the fields selected by the SELECT statement.
func (rs *Rows) Columns() []string {
return rs.res.Columns()
return rs.columns
}
// Close closes the rows iterator.
@@ -376,26 +380,3 @@ func (rs *Rows) Next(dest []driver.Value) error {
return nil
}
type valueScanner struct {
dest any
}
func (v valueScanner) Scan(src any) error {
if r, ok := src.(*chai.Row); ok {
return r.StructScan(v.dest)
}
vv, err := row.NewValue(src)
if err != nil {
return err
}
return row.ScanValue(vv, v.dest)
}
// Scanner turns a variable into a sql.Scanner.
// x must be a pointer to a valid variable.
func Scanner(x any) sql.Scanner {
return valueScanner{x}
}

View File

@@ -40,7 +40,7 @@ func TestDriver(t *testing.T) {
var count int
var rt rowtest
for rows.Next() {
err = rows.Scan(Scanner(&rt))
err = rows.Scan(&rt.A, &rt.B, &rt.C)
require.NoError(t, err)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
count++
@@ -50,7 +50,7 @@ func TestDriver(t *testing.T) {
require.Equal(t, 10, count)
})
t.Run("Multiple fields", func(t *testing.T) {
t.Run("Multiple columns", func(t *testing.T) {
rows, err := db.Query("SELECT a, c FROM test")
require.NoError(t, err)
defer rows.Close()
@@ -59,7 +59,7 @@ func TestDriver(t *testing.T) {
var a int
var c bool
for rows.Next() {
err = rows.Scan(&a, Scanner(&c))
err = rows.Scan(&a, &c)
require.NoError(t, err)
require.Equal(t, count, a)
require.Equal(t, count%2 == 0, c)
@@ -78,7 +78,7 @@ func TestDriver(t *testing.T) {
var a int
var c bool
for rows.Next() {
err = rows.Scan(&a, Scanner(&c))
err = rows.Scan(&a, &c)
require.NoError(t, err)
require.Equal(t, count, a)
require.Equal(t, count%2 == 0, c)
@@ -96,7 +96,7 @@ func TestDriver(t *testing.T) {
var count int
var rt rowtest
for rows.Next() {
err = rows.Scan(Scanner(&rt))
err = rows.Scan(&rt.A, &rt.B, &rt.C)
require.NoError(t, err)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
count++
@@ -113,7 +113,7 @@ func TestDriver(t *testing.T) {
var count int
var rt rowtest
for rows.Next() {
err = rows.Scan(Scanner(&rt))
err = rows.Scan(&rt.A, &rt.B, &rt.C)
require.NoError(t, err)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
count++
@@ -134,7 +134,7 @@ func TestDriver(t *testing.T) {
var c bool
var dt1, dt2 rowtest
for rows.Next() {
err = rows.Scan(&a, Scanner(&aa), Scanner(&dt1), Scanner(&b), Scanner(&c), Scanner(&dt2))
err = rows.Scan(&a, &aa, &dt1.A, &dt1.B, &dt1.C, &b, &c, &dt2.A, &dt2.B, &dt2.C)
require.NoError(t, err)
require.Equal(t, count, a)
require.Equal(t, fmt.Sprintf("foo%d", count), b)
@@ -192,11 +192,11 @@ func TestDriver(t *testing.T) {
defer rows.Close()
var count int
var dt rowtest
var rt rowtest
for rows.Next() {
err = rows.Scan(Scanner(&dt))
err = rows.Scan(&rt.A, &rt.B, &rt.C)
require.NoError(t, err)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
count++
}
require.NoError(t, rows.Err())
@@ -213,11 +213,11 @@ func TestDriver(t *testing.T) {
defer rows.Close()
var count int
var dt rowtest
var rt rowtest
for rows.Next() {
err = rows.Scan(Scanner(&dt))
err = rows.Scan(&rt.A, &rt.B, &rt.C)
require.NoError(t, err)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
count++
}
require.NoError(t, rows.Err())
@@ -238,11 +238,11 @@ func TestDriver(t *testing.T) {
defer rows.Close()
var count int
var dt rowtest
var rt rowtest
for rows.Next() {
err = rows.Scan(Scanner(&dt))
err = rows.Scan(&rt.A, &rt.B, &rt.C)
require.NoError(t, err)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt)
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
count++
}
require.NoError(t, rows.Err())
@@ -277,7 +277,7 @@ func TestDriverWithTimeValues(t *testing.T) {
defer tx.Rollback()
var tt time.Time
err = tx.QueryRow(`SELECT a FROM test`).Scan(Scanner(&tt))
err = tx.QueryRow(`SELECT a FROM test`).Scan(&tt)
require.NoError(t, err)
require.Equal(t, now, tt)
}

View File

@@ -3,8 +3,6 @@ package driver_test
import (
"database/sql"
"fmt"
"github.com/chaisql/chai/driver"
)
type User struct {
@@ -43,7 +41,7 @@ func Example() {
for rows.Next() {
var u User
err = rows.Scan(driver.Scanner(&u))
err = rows.Scan(&u.ID, &u.Name, &u.Age)
if err != nil {
panic(err)
}

2
go.mod
View File

@@ -9,7 +9,6 @@ require (
github.com/golang-module/carbon/v2 v2.3.8
github.com/google/go-cmp v0.6.0
github.com/stretchr/testify v1.8.4
golang.org/x/sync v0.6.0
)
require (
@@ -34,6 +33,7 @@ require (
github.com/prometheus/procfs v0.12.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/exp v0.0.0-20240213143201-ec583247a57a // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.32.0 // indirect

View File

@@ -37,6 +37,7 @@ func NewInsertStatement() *InsertStmt {
func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
var s *stream.Stream
var columns []string
if stmt.Values != nil {
ti, err := c.Tx.Catalog.GetTableInfo(stmt.TableName)
if err != nil {
@@ -46,6 +47,7 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
var rowList []expr.Row
// if no columns have been specified, we need to inject the columns from the defined table info
if len(stmt.Columns) == 0 {
rowList = make([]expr.Row, 0, len(stmt.Values))
for i := range stmt.Values {
var r expr.Row
@@ -60,9 +62,13 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
r.Columns = append(r.Columns, ti.ColumnConstraints.Ordered[i].Column)
}
columns = r.Columns
rowList = append(rowList, r)
}
} else {
columns = stmt.Columns
rowList = make([]expr.Row, 0, len(stmt.Values))
for i := range stmt.Columns {
_, ok := ti.ColumnConstraints.ByColumn[stmt.Columns[i]]
@@ -81,10 +87,14 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
}
r.Columns = stmt.Columns
if len(stmt.Columns) != len(r.Exprs) {
return nil, errors.Errorf("expected %d columns, got %d", len(stmt.Columns), len(stmt.Values))
}
rowList = append(rowList, r)
}
}
s = stream.New(rows.Emit(rowList...))
s = stream.New(rows.Emit(columns, rowList...))
} else {
selectStream, err := stmt.SelectStmt.Prepare(c)
if err != nil {

View File

@@ -25,6 +25,7 @@ func TestParserInsert(t *testing.T) {
}{
{"Values / With fields", "INSERT INTO test (a, b) VALUES ('c', 'd')",
stream.New(rows.Emit(
[]string{"a", "b"},
expr.Row{
Columns: []string{"a", "b"},
Exprs: []expr.Expr{
@@ -41,6 +42,7 @@ func TestParserInsert(t *testing.T) {
nil, true},
{"Values / Multiple", "INSERT INTO test (a, b) VALUES ('c', 'd'), ('e', 'f')",
stream.New(rows.Emit(
[]string{"a", "b"},
expr.Row{
Columns: []string{"a", "b"},
Exprs: []expr.Expr{
@@ -62,6 +64,7 @@ func TestParserInsert(t *testing.T) {
false},
{"Values / Returning", "INSERT INTO test (a, b) VALUES ('c', 'd') RETURNING *, a, b as B, c",
stream.New(rows.Emit(
[]string{"a", "b"},
expr.Row{
Columns: []string{"a", "b"},
Exprs: []expr.Expr{
@@ -80,6 +83,7 @@ func TestParserInsert(t *testing.T) {
nil, true},
{"Values / ON CONFLICT DO NOTHING", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT DO NOTHING RETURNING *",
stream.New(rows.Emit(
[]string{"a", "b"},
expr.Row{
Columns: []string{"a", "b"},
Exprs: []expr.Expr{
@@ -95,6 +99,7 @@ func TestParserInsert(t *testing.T) {
false},
{"Values / ON CONFLICT IGNORE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT IGNORE RETURNING *",
stream.New(rows.Emit(
[]string{"a", "b"},
expr.Row{
Columns: []string{"a", "b"},
Exprs: []expr.Expr{
@@ -109,6 +114,7 @@ func TestParserInsert(t *testing.T) {
false},
{"Values / ON CONFLICT DO REPLACE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT DO REPLACE RETURNING *",
stream.New(rows.Emit(
[]string{"a", "b"},
expr.Row{
Columns: []string{"a", "b"},
Exprs: []expr.Expr{
@@ -124,6 +130,7 @@ func TestParserInsert(t *testing.T) {
false},
{"Values / ON CONFLICT REPLACE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT REPLACE RETURNING *",
stream.New(rows.Emit(
[]string{"a", "b"},
expr.Row{
Columns: []string{"a", "b"},
Exprs: []expr.Expr{

View File

@@ -29,6 +29,14 @@ func (it *ConcatOperator) Clone() Operator {
}
}
func (it *ConcatOperator) Columns(env *environment.Environment) ([]string, error) {
if len(it.Streams) == 0 {
return nil, nil
}
return it.Streams[0].Columns(env)
}
func (it *ConcatOperator) Iterate(in *environment.Environment, fn func(*environment.Environment) error) error {
for _, s := range it.Streams {
if err := s.Iterate(in, fn); err != nil {

View File

@@ -106,6 +106,27 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro
return nil
}
func (it *ScanOperator) Columns(env *environment.Environment) ([]string, error) {
tx := env.GetTx()
idxInfo, err := tx.Catalog.GetIndexInfo(it.IndexName)
if err != nil {
return nil, err
}
info, err := tx.Catalog.GetTableInfo(idxInfo.Owner.TableName)
if err != nil {
return nil, err
}
columns := make([]string, len(info.ColumnConstraints.Ordered))
for i, c := range info.ColumnConstraints.Ordered {
columns[i] = c.Column
}
return columns, nil
}
func (it *ScanOperator) String() string {
var s strings.Builder

View File

@@ -26,6 +26,7 @@ type Operator interface {
GetPrev() Operator
String() string
Clone() Operator
Columns(env *environment.Environment) ([]string, error)
}
// An OperatorFunc is the function that will receive each value of the stream.
@@ -64,3 +65,11 @@ func (op *BaseOperator) GetNext() Operator {
func (op BaseOperator) Clone() BaseOperator {
return op
}
func (op *BaseOperator) Columns(env *environment.Environment) ([]string, error) {
if op.Prev == nil {
return nil, nil
}
return op.Prev.Columns(env)
}

View File

@@ -55,7 +55,7 @@ func TestFilter(t *testing.T) {
for _, test := range tests {
t.Run(test.e.String(), func(t *testing.T) {
s := stream.New(rows.Emit(test.in...)).Pipe(rows.Filter(test.e))
s := stream.New(rows.Emit([]string{"a"}, test.in...)).Pipe(rows.Filter(test.e))
i := 0
err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error {
r, _ := out.GetRow()
@@ -97,7 +97,7 @@ func TestTake(t *testing.T) {
ds = append(ds, testutil.MakeRowExpr(t, `{"a": `+strconv.Itoa(i)+`}`))
}
s := stream.New(rows.Emit(ds...))
s := stream.New(rows.Emit([]string{"a"}, ds...))
s = s.Pipe(rows.Take(parser.MustParseExpr(strconv.Itoa(test.n))))
var count int
@@ -142,7 +142,7 @@ func TestSkip(t *testing.T) {
ds = append(ds, testutil.MakeRowExpr(t, `{"a": `+strconv.Itoa(i)+`}`))
}
s := stream.New(rows.Emit(ds...))
s := stream.New(rows.Emit([]string{"a"}, ds...))
s = s.Pipe(rows.Skip(parser.MustParseExpr(strconv.Itoa(test.n))))
var count int
@@ -174,7 +174,7 @@ func TestTableInsert(t *testing.T) {
}{
{
"doc with no key",
rows.Emit(testutil.MakeRowExpr(t, `{"a": 10}`), testutil.MakeRowExpr(t, `{"a": 11}`)),
rows.Emit([]string{"a"}, testutil.MakeRowExpr(t, `{"a": 10}`), testutil.MakeRowExpr(t, `{"a": 11}`)),
[]row.Row{testutil.MakeRow(t, `{"a": 10}`), testutil.MakeRow(t, `{"a": 11}`)},
1,
false,

View File

@@ -89,6 +89,10 @@ func (op *RenameOperator) Iterate(in *environment.Environment, f func(out *envir
})
}
func (op *RenameOperator) Columns(env *environment.Environment) ([]string, error) {
return op.ColumnNames, nil
}
func (op *RenameOperator) String() string {
return fmt.Sprintf("paths.Rename(%s)", strings.Join(op.ColumnNames, ", "))
}

View File

@@ -41,7 +41,7 @@ func TestPathsRename(t *testing.T) {
}
for _, test := range tests {
s := stream.New(rows.Emit(test.in...)).Pipe(path.PathsRename(test.fieldNames...))
s := stream.New(rows.Emit([]string{"a", "b"}, test.in...)).Pipe(path.PathsRename(test.fieldNames...))
t.Run(s.String(), func(t *testing.T) {
i := 0
err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error {

View File

@@ -40,7 +40,7 @@ func TestSet(t *testing.T) {
for _, test := range tests {
t.Run(test.column, func(t *testing.T) {
s := stream.New(rows.Emit(test.in...)).Pipe(path.Set(test.column, test.e))
s := stream.New(rows.Emit([]string{"a"}, test.in...)).Pipe(path.Set(test.column, test.e))
i := 0
err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error {
r, _ := out.GetRow()

View File

@@ -10,13 +10,14 @@ import (
type EmitOperator struct {
stream.BaseOperator
Rows []expr.Row
Rows []expr.Row
columns []string
}
// Emit creates an operator that iterates over the given expressions.
// Each expression must evaluate to an row.
func Emit(rows ...expr.Row) *EmitOperator {
return &EmitOperator{Rows: rows}
func Emit(columns []string, rows ...expr.Row) *EmitOperator {
return &EmitOperator{columns: columns, Rows: rows}
}
func (op *EmitOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error {
@@ -40,6 +41,10 @@ func (op *EmitOperator) Iterate(in *environment.Environment, fn func(out *enviro
return nil
}
func (it *EmitOperator) Columns(env *environment.Environment) ([]string, error) {
return it.columns, nil
}
func (op *EmitOperator) Clone() stream.Operator {
return &EmitOperator{
BaseOperator: op.BaseOperator.Clone(),

View File

@@ -117,6 +117,19 @@ func (op *GroupAggregateOperator) Iterate(in *environment.Environment, f func(ou
return f(e)
}
func (op *GroupAggregateOperator) Columns(env *environment.Environment) ([]string, error) {
columns := make([]string, 0, len(op.Builders)+1)
if op.E != nil {
columns = append(columns, op.E.String())
}
for _, agg := range op.Builders {
columns = append(columns, agg.String())
}
return columns, nil
}
func (op *GroupAggregateOperator) String() string {
var sb strings.Builder

View File

@@ -35,6 +35,28 @@ func (op *ProjectOperator) Clone() stream.Operator {
}
}
func (op *ProjectOperator) Columns(env *environment.Environment) ([]string, error) {
var cols, prev []string
var err error
for _, e := range op.Exprs {
if _, ok := e.(expr.Wildcard); ok {
if prev == nil {
prev, err = op.Prev.Columns(env)
if err != nil {
return nil, err
}
}
cols = append(cols, prev...)
} else {
cols = append(cols, e.String())
}
}
return cols, nil
}
// Iterate implements the Operator interface.
func (op *ProjectOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error {
var mask RowMask

View File

@@ -26,6 +26,14 @@ func (s *Stream) Pipe(op Operator) *Stream {
return s
}
func (s *Stream) Columns(env *environment.Environment) ([]string, error) {
if s.Op == nil {
return nil, nil
}
return s.Op.Columns(env)
}
func (s *Stream) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error {
if s.Op == nil {
return nil

View File

@@ -17,6 +17,7 @@ import (
func TestStream(t *testing.T) {
s := stream.New(rows.Emit(
[]string{"a"},
testutil.MakeRowExpr(t, `{"a": 1}`),
testutil.MakeRowExpr(t, `{"a": 2}`),
))
@@ -80,13 +81,13 @@ func TestUnion(t *testing.T) {
var streams []*stream.Stream
if test.first != nil {
streams = append(streams, stream.New(rows.Emit(test.first...)))
streams = append(streams, stream.New(rows.Emit([]string{"a", "b"}, test.first...)))
}
if test.second != nil {
streams = append(streams, stream.New(rows.Emit(test.second...)))
streams = append(streams, stream.New(rows.Emit([]string{"a", "b"}, test.second...)))
}
if test.third != nil {
streams = append(streams, stream.New(rows.Emit(test.third...)))
streams = append(streams, stream.New(rows.Emit([]string{"a", "b"}, test.third...)))
}
st := stream.New(stream.Union(streams...))
@@ -100,9 +101,9 @@ func TestUnion(t *testing.T) {
t.Run("String", func(t *testing.T) {
st := stream.New(stream.Union(
stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 2}`)...)),
stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 3}`, `{"a": 4}`)...)),
stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 5}`, `{"a": 6}`)...)),
stream.New(rows.Emit([]string{"a"}, testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 2}`)...)),
stream.New(rows.Emit([]string{"a"}, testutil.MakeRowExprs(t, `{"a": 3}`, `{"a": 4}`)...)),
stream.New(rows.Emit([]string{"a"}, testutil.MakeRowExprs(t, `{"a": 5}`, `{"a": 6}`)...)),
))
require.Equal(t, `union(rows.Emit((1), (2)), rows.Emit((3), (4)), rows.Emit((5), (6)))`, st.String())
@@ -113,8 +114,8 @@ func TestConcatOperator(t *testing.T) {
in1 := testutil.MakeRowExprs(t, `{"a": 10}`, `{"a": 11}`)
in2 := testutil.MakeRowExprs(t, `{"a": 12}`, `{"a": 13}`)
s1 := stream.New(rows.Emit(in1...))
s2 := stream.New(rows.Emit(in2...))
s1 := stream.New(rows.Emit([]string{"a"}, in1...))
s2 := stream.New(rows.Emit([]string{"a"}, in2...))
s := stream.Concat(s1, s2)
var got []row.Row

View File

@@ -86,6 +86,22 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro
return nil
}
func (it *ScanOperator) Columns(env *environment.Environment) ([]string, error) {
tx := env.GetTx()
info, err := tx.Catalog.GetTableInfo(it.TableName)
if err != nil {
return nil, err
}
columns := make([]string, len(info.ColumnConstraints.Ordered))
for i, c := range info.ColumnConstraints.Ordered {
columns[i] = c.Column
}
return columns, nil
}
func (it *ScanOperator) String() string {
var s strings.Builder

View File

@@ -34,6 +34,14 @@ func (it *UnionOperator) Clone() Operator {
}
}
func (it *UnionOperator) Columns(env *environment.Environment) ([]string, error) {
if len(it.Streams) == 0 {
return nil, nil
}
return it.Streams[0].Columns(env)
}
// Iterate iterates over all the streams and returns their union.
func (it *UnionOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) (err error) {
var temp *tree.Tree