mirror of
https://github.com/chaisql/chai.git
synced 2025-09-26 19:51:21 +08:00
stream: add Columns method
This commit is contained in:
30
db.go
30
db.go
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/chaisql/chai/internal/row"
|
||||
"github.com/chaisql/chai/internal/sql/parser"
|
||||
"github.com/chaisql/chai/internal/stream"
|
||||
"github.com/chaisql/chai/internal/stream/rows"
|
||||
"github.com/chaisql/chai/internal/types"
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
@@ -411,35 +410,22 @@ func (r *Result) GetFirst() (*Row, error) {
|
||||
return rr, nil
|
||||
}
|
||||
|
||||
func (r *Result) Columns() []string {
|
||||
func (r *Result) Columns() ([]string, error) {
|
||||
if r.result.Iterator == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
stmt, ok := r.result.Iterator.(*statement.StreamStmtIterator)
|
||||
if !ok || stmt.Stream.Op == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Search for the ProjectOperator. If found, extract the projected expression list
|
||||
for op := stmt.Stream.First(); op != nil; op = op.GetNext() {
|
||||
if po, ok := op.(*rows.ProjectOperator); ok {
|
||||
// if there are no projected expression, it's a wildcard
|
||||
if len(po.Exprs) == 0 {
|
||||
break
|
||||
}
|
||||
var env environment.Environment
|
||||
env.DB = stmt.Context.DB
|
||||
env.Tx = stmt.Context.Tx
|
||||
env.SetParams(stmt.Context.Params)
|
||||
|
||||
columns := make([]string, len(po.Exprs))
|
||||
for i := range po.Exprs {
|
||||
columns[i] = po.Exprs[i].String()
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
}
|
||||
|
||||
// the stream will output rows in a single field
|
||||
return []string{"*"}
|
||||
return stmt.Stream.Columns(&env)
|
||||
}
|
||||
|
||||
// Close the result stream.
|
||||
|
38
db_test.go
38
db_test.go
@@ -1,7 +1,6 @@
|
||||
package chai_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -10,7 +9,6 @@ import (
|
||||
"github.com/chaisql/chai"
|
||||
"github.com/chaisql/chai/internal/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func ExampleTx() {
|
||||
@@ -212,42 +210,6 @@ func TestQueryRow(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareThreadSafe(t *testing.T) {
|
||||
db, err := chai.Open(":memory:")
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
conn, err := db.Connect()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Exec("CREATE TABLE test(a int unique, b text); INSERT INTO test(a, b) VALUES (1, 'a'), (2, 'a')")
|
||||
require.NoError(t, err)
|
||||
|
||||
stmt, err := conn.Prepare("SELECT COUNT(a) FROM test WHERE a < ? GROUP BY b ORDER BY a DESC LIMIT 5")
|
||||
require.NoError(t, err)
|
||||
|
||||
g, _ := errgroup.WithContext(context.Background())
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
arg := i
|
||||
g.Go(func() error {
|
||||
res, err := stmt.Query(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Close()
|
||||
|
||||
return res.Iterate(func(d *chai.Row) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
err = g.Wait()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestIterateDeepCopy(t *testing.T) {
|
||||
db, err := chai.Open(":memory:")
|
||||
require.NoError(t, err)
|
||||
|
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/chaisql/chai"
|
||||
"github.com/chaisql/chai/internal/environment"
|
||||
"github.com/chaisql/chai/internal/row"
|
||||
"github.com/chaisql/chai/internal/types"
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
@@ -192,9 +191,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rs := newRows(res)
|
||||
rs.columns = res.Columns()
|
||||
return rs, nil
|
||||
return newRows(res)
|
||||
}
|
||||
|
||||
func namedValueToParams(args []driver.NamedValue) []any {
|
||||
@@ -229,7 +226,7 @@ type Row struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func newRows(res *chai.Result) *Rows {
|
||||
func newRows(res *chai.Result) (*Rows, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
rs := Rows{
|
||||
@@ -239,9 +236,16 @@ func newRows(res *chai.Result) *Rows {
|
||||
}
|
||||
rs.wg.Add(1)
|
||||
|
||||
cols, err := rs.res.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rs.columns = cols
|
||||
|
||||
go rs.iterate(ctx)
|
||||
|
||||
return &rs
|
||||
return &rs, nil
|
||||
}
|
||||
|
||||
func (rs *Rows) iterate(ctx context.Context) {
|
||||
@@ -284,7 +288,7 @@ func (rs *Rows) iterate(ctx context.Context) {
|
||||
|
||||
// Columns returns the fields selected by the SELECT statement.
|
||||
func (rs *Rows) Columns() []string {
|
||||
return rs.res.Columns()
|
||||
return rs.columns
|
||||
}
|
||||
|
||||
// Close closes the rows iterator.
|
||||
@@ -376,26 +380,3 @@ func (rs *Rows) Next(dest []driver.Value) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type valueScanner struct {
|
||||
dest any
|
||||
}
|
||||
|
||||
func (v valueScanner) Scan(src any) error {
|
||||
if r, ok := src.(*chai.Row); ok {
|
||||
return r.StructScan(v.dest)
|
||||
}
|
||||
|
||||
vv, err := row.NewValue(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return row.ScanValue(vv, v.dest)
|
||||
}
|
||||
|
||||
// Scanner turns a variable into a sql.Scanner.
|
||||
// x must be a pointer to a valid variable.
|
||||
func Scanner(x any) sql.Scanner {
|
||||
return valueScanner{x}
|
||||
}
|
||||
|
@@ -40,7 +40,7 @@ func TestDriver(t *testing.T) {
|
||||
var count int
|
||||
var rt rowtest
|
||||
for rows.Next() {
|
||||
err = rows.Scan(Scanner(&rt))
|
||||
err = rows.Scan(&rt.A, &rt.B, &rt.C)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
|
||||
count++
|
||||
@@ -50,7 +50,7 @@ func TestDriver(t *testing.T) {
|
||||
require.Equal(t, 10, count)
|
||||
})
|
||||
|
||||
t.Run("Multiple fields", func(t *testing.T) {
|
||||
t.Run("Multiple columns", func(t *testing.T) {
|
||||
rows, err := db.Query("SELECT a, c FROM test")
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
@@ -59,7 +59,7 @@ func TestDriver(t *testing.T) {
|
||||
var a int
|
||||
var c bool
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&a, Scanner(&c))
|
||||
err = rows.Scan(&a, &c)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, count, a)
|
||||
require.Equal(t, count%2 == 0, c)
|
||||
@@ -78,7 +78,7 @@ func TestDriver(t *testing.T) {
|
||||
var a int
|
||||
var c bool
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&a, Scanner(&c))
|
||||
err = rows.Scan(&a, &c)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, count, a)
|
||||
require.Equal(t, count%2 == 0, c)
|
||||
@@ -96,7 +96,7 @@ func TestDriver(t *testing.T) {
|
||||
var count int
|
||||
var rt rowtest
|
||||
for rows.Next() {
|
||||
err = rows.Scan(Scanner(&rt))
|
||||
err = rows.Scan(&rt.A, &rt.B, &rt.C)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
|
||||
count++
|
||||
@@ -113,7 +113,7 @@ func TestDriver(t *testing.T) {
|
||||
var count int
|
||||
var rt rowtest
|
||||
for rows.Next() {
|
||||
err = rows.Scan(Scanner(&rt))
|
||||
err = rows.Scan(&rt.A, &rt.B, &rt.C)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
|
||||
count++
|
||||
@@ -134,7 +134,7 @@ func TestDriver(t *testing.T) {
|
||||
var c bool
|
||||
var dt1, dt2 rowtest
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&a, Scanner(&aa), Scanner(&dt1), Scanner(&b), Scanner(&c), Scanner(&dt2))
|
||||
err = rows.Scan(&a, &aa, &dt1.A, &dt1.B, &dt1.C, &b, &c, &dt2.A, &dt2.B, &dt2.C)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, count, a)
|
||||
require.Equal(t, fmt.Sprintf("foo%d", count), b)
|
||||
@@ -192,11 +192,11 @@ func TestDriver(t *testing.T) {
|
||||
defer rows.Close()
|
||||
|
||||
var count int
|
||||
var dt rowtest
|
||||
var rt rowtest
|
||||
for rows.Next() {
|
||||
err = rows.Scan(Scanner(&dt))
|
||||
err = rows.Scan(&rt.A, &rt.B, &rt.C)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
|
||||
count++
|
||||
}
|
||||
require.NoError(t, rows.Err())
|
||||
@@ -213,11 +213,11 @@ func TestDriver(t *testing.T) {
|
||||
defer rows.Close()
|
||||
|
||||
var count int
|
||||
var dt rowtest
|
||||
var rt rowtest
|
||||
for rows.Next() {
|
||||
err = rows.Scan(Scanner(&dt))
|
||||
err = rows.Scan(&rt.A, &rt.B, &rt.C)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
|
||||
count++
|
||||
}
|
||||
require.NoError(t, rows.Err())
|
||||
@@ -238,11 +238,11 @@ func TestDriver(t *testing.T) {
|
||||
defer rows.Close()
|
||||
|
||||
var count int
|
||||
var dt rowtest
|
||||
var rt rowtest
|
||||
for rows.Next() {
|
||||
err = rows.Scan(Scanner(&dt))
|
||||
err = rows.Scan(&rt.A, &rt.B, &rt.C)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt)
|
||||
require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt)
|
||||
count++
|
||||
}
|
||||
require.NoError(t, rows.Err())
|
||||
@@ -277,7 +277,7 @@ func TestDriverWithTimeValues(t *testing.T) {
|
||||
defer tx.Rollback()
|
||||
|
||||
var tt time.Time
|
||||
err = tx.QueryRow(`SELECT a FROM test`).Scan(Scanner(&tt))
|
||||
err = tx.QueryRow(`SELECT a FROM test`).Scan(&tt)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, now, tt)
|
||||
}
|
||||
|
@@ -3,8 +3,6 @@ package driver_test
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/chaisql/chai/driver"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
@@ -43,7 +41,7 @@ func Example() {
|
||||
|
||||
for rows.Next() {
|
||||
var u User
|
||||
err = rows.Scan(driver.Scanner(&u))
|
||||
err = rows.Scan(&u.ID, &u.Name, &u.Age)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
2
go.mod
2
go.mod
@@ -9,7 +9,6 @@ require (
|
||||
github.com/golang-module/carbon/v2 v2.3.8
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
golang.org/x/sync v0.6.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -34,6 +33,7 @@ require (
|
||||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240213143201-ec583247a57a // indirect
|
||||
golang.org/x/sync v0.6.0 // indirect
|
||||
golang.org/x/sys v0.17.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.32.0 // indirect
|
||||
|
@@ -37,6 +37,7 @@ func NewInsertStatement() *InsertStmt {
|
||||
func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
||||
var s *stream.Stream
|
||||
|
||||
var columns []string
|
||||
if stmt.Values != nil {
|
||||
ti, err := c.Tx.Catalog.GetTableInfo(stmt.TableName)
|
||||
if err != nil {
|
||||
@@ -46,6 +47,7 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
||||
var rowList []expr.Row
|
||||
// if no columns have been specified, we need to inject the columns from the defined table info
|
||||
if len(stmt.Columns) == 0 {
|
||||
|
||||
rowList = make([]expr.Row, 0, len(stmt.Values))
|
||||
for i := range stmt.Values {
|
||||
var r expr.Row
|
||||
@@ -60,9 +62,13 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
||||
r.Columns = append(r.Columns, ti.ColumnConstraints.Ordered[i].Column)
|
||||
}
|
||||
|
||||
columns = r.Columns
|
||||
|
||||
rowList = append(rowList, r)
|
||||
}
|
||||
} else {
|
||||
columns = stmt.Columns
|
||||
|
||||
rowList = make([]expr.Row, 0, len(stmt.Values))
|
||||
for i := range stmt.Columns {
|
||||
_, ok := ti.ColumnConstraints.ByColumn[stmt.Columns[i]]
|
||||
@@ -81,10 +87,14 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
||||
}
|
||||
|
||||
r.Columns = stmt.Columns
|
||||
if len(stmt.Columns) != len(r.Exprs) {
|
||||
return nil, errors.Errorf("expected %d columns, got %d", len(stmt.Columns), len(stmt.Values))
|
||||
}
|
||||
rowList = append(rowList, r)
|
||||
}
|
||||
}
|
||||
s = stream.New(rows.Emit(rowList...))
|
||||
|
||||
s = stream.New(rows.Emit(columns, rowList...))
|
||||
} else {
|
||||
selectStream, err := stmt.SelectStmt.Prepare(c)
|
||||
if err != nil {
|
||||
|
@@ -25,6 +25,7 @@ func TestParserInsert(t *testing.T) {
|
||||
}{
|
||||
{"Values / With fields", "INSERT INTO test (a, b) VALUES ('c', 'd')",
|
||||
stream.New(rows.Emit(
|
||||
[]string{"a", "b"},
|
||||
expr.Row{
|
||||
Columns: []string{"a", "b"},
|
||||
Exprs: []expr.Expr{
|
||||
@@ -41,6 +42,7 @@ func TestParserInsert(t *testing.T) {
|
||||
nil, true},
|
||||
{"Values / Multiple", "INSERT INTO test (a, b) VALUES ('c', 'd'), ('e', 'f')",
|
||||
stream.New(rows.Emit(
|
||||
[]string{"a", "b"},
|
||||
expr.Row{
|
||||
Columns: []string{"a", "b"},
|
||||
Exprs: []expr.Expr{
|
||||
@@ -62,6 +64,7 @@ func TestParserInsert(t *testing.T) {
|
||||
false},
|
||||
{"Values / Returning", "INSERT INTO test (a, b) VALUES ('c', 'd') RETURNING *, a, b as B, c",
|
||||
stream.New(rows.Emit(
|
||||
[]string{"a", "b"},
|
||||
expr.Row{
|
||||
Columns: []string{"a", "b"},
|
||||
Exprs: []expr.Expr{
|
||||
@@ -80,6 +83,7 @@ func TestParserInsert(t *testing.T) {
|
||||
nil, true},
|
||||
{"Values / ON CONFLICT DO NOTHING", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT DO NOTHING RETURNING *",
|
||||
stream.New(rows.Emit(
|
||||
[]string{"a", "b"},
|
||||
expr.Row{
|
||||
Columns: []string{"a", "b"},
|
||||
Exprs: []expr.Expr{
|
||||
@@ -95,6 +99,7 @@ func TestParserInsert(t *testing.T) {
|
||||
false},
|
||||
{"Values / ON CONFLICT IGNORE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT IGNORE RETURNING *",
|
||||
stream.New(rows.Emit(
|
||||
[]string{"a", "b"},
|
||||
expr.Row{
|
||||
Columns: []string{"a", "b"},
|
||||
Exprs: []expr.Expr{
|
||||
@@ -109,6 +114,7 @@ func TestParserInsert(t *testing.T) {
|
||||
false},
|
||||
{"Values / ON CONFLICT DO REPLACE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT DO REPLACE RETURNING *",
|
||||
stream.New(rows.Emit(
|
||||
[]string{"a", "b"},
|
||||
expr.Row{
|
||||
Columns: []string{"a", "b"},
|
||||
Exprs: []expr.Expr{
|
||||
@@ -124,6 +130,7 @@ func TestParserInsert(t *testing.T) {
|
||||
false},
|
||||
{"Values / ON CONFLICT REPLACE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT REPLACE RETURNING *",
|
||||
stream.New(rows.Emit(
|
||||
[]string{"a", "b"},
|
||||
expr.Row{
|
||||
Columns: []string{"a", "b"},
|
||||
Exprs: []expr.Expr{
|
||||
|
@@ -29,6 +29,14 @@ func (it *ConcatOperator) Clone() Operator {
|
||||
}
|
||||
}
|
||||
|
||||
func (it *ConcatOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
if len(it.Streams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return it.Streams[0].Columns(env)
|
||||
}
|
||||
|
||||
func (it *ConcatOperator) Iterate(in *environment.Environment, fn func(*environment.Environment) error) error {
|
||||
for _, s := range it.Streams {
|
||||
if err := s.Iterate(in, fn); err != nil {
|
||||
|
@@ -106,6 +106,27 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *ScanOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
tx := env.GetTx()
|
||||
|
||||
idxInfo, err := tx.Catalog.GetIndexInfo(it.IndexName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info, err := tx.Catalog.GetTableInfo(idxInfo.Owner.TableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]string, len(info.ColumnConstraints.Ordered))
|
||||
for i, c := range info.ColumnConstraints.Ordered {
|
||||
columns[i] = c.Column
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (it *ScanOperator) String() string {
|
||||
var s strings.Builder
|
||||
|
||||
|
@@ -26,6 +26,7 @@ type Operator interface {
|
||||
GetPrev() Operator
|
||||
String() string
|
||||
Clone() Operator
|
||||
Columns(env *environment.Environment) ([]string, error)
|
||||
}
|
||||
|
||||
// An OperatorFunc is the function that will receive each value of the stream.
|
||||
@@ -64,3 +65,11 @@ func (op *BaseOperator) GetNext() Operator {
|
||||
func (op BaseOperator) Clone() BaseOperator {
|
||||
return op
|
||||
}
|
||||
|
||||
func (op *BaseOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
if op.Prev == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return op.Prev.Columns(env)
|
||||
}
|
||||
|
@@ -55,7 +55,7 @@ func TestFilter(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.e.String(), func(t *testing.T) {
|
||||
s := stream.New(rows.Emit(test.in...)).Pipe(rows.Filter(test.e))
|
||||
s := stream.New(rows.Emit([]string{"a"}, test.in...)).Pipe(rows.Filter(test.e))
|
||||
i := 0
|
||||
err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error {
|
||||
r, _ := out.GetRow()
|
||||
@@ -97,7 +97,7 @@ func TestTake(t *testing.T) {
|
||||
ds = append(ds, testutil.MakeRowExpr(t, `{"a": `+strconv.Itoa(i)+`}`))
|
||||
}
|
||||
|
||||
s := stream.New(rows.Emit(ds...))
|
||||
s := stream.New(rows.Emit([]string{"a"}, ds...))
|
||||
s = s.Pipe(rows.Take(parser.MustParseExpr(strconv.Itoa(test.n))))
|
||||
|
||||
var count int
|
||||
@@ -142,7 +142,7 @@ func TestSkip(t *testing.T) {
|
||||
ds = append(ds, testutil.MakeRowExpr(t, `{"a": `+strconv.Itoa(i)+`}`))
|
||||
}
|
||||
|
||||
s := stream.New(rows.Emit(ds...))
|
||||
s := stream.New(rows.Emit([]string{"a"}, ds...))
|
||||
s = s.Pipe(rows.Skip(parser.MustParseExpr(strconv.Itoa(test.n))))
|
||||
|
||||
var count int
|
||||
@@ -174,7 +174,7 @@ func TestTableInsert(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
"doc with no key",
|
||||
rows.Emit(testutil.MakeRowExpr(t, `{"a": 10}`), testutil.MakeRowExpr(t, `{"a": 11}`)),
|
||||
rows.Emit([]string{"a"}, testutil.MakeRowExpr(t, `{"a": 10}`), testutil.MakeRowExpr(t, `{"a": 11}`)),
|
||||
[]row.Row{testutil.MakeRow(t, `{"a": 10}`), testutil.MakeRow(t, `{"a": 11}`)},
|
||||
1,
|
||||
false,
|
||||
|
@@ -89,6 +89,10 @@ func (op *RenameOperator) Iterate(in *environment.Environment, f func(out *envir
|
||||
})
|
||||
}
|
||||
|
||||
func (op *RenameOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
return op.ColumnNames, nil
|
||||
}
|
||||
|
||||
func (op *RenameOperator) String() string {
|
||||
return fmt.Sprintf("paths.Rename(%s)", strings.Join(op.ColumnNames, ", "))
|
||||
}
|
||||
|
@@ -41,7 +41,7 @@ func TestPathsRename(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := stream.New(rows.Emit(test.in...)).Pipe(path.PathsRename(test.fieldNames...))
|
||||
s := stream.New(rows.Emit([]string{"a", "b"}, test.in...)).Pipe(path.PathsRename(test.fieldNames...))
|
||||
t.Run(s.String(), func(t *testing.T) {
|
||||
i := 0
|
||||
err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error {
|
||||
|
@@ -40,7 +40,7 @@ func TestSet(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.column, func(t *testing.T) {
|
||||
s := stream.New(rows.Emit(test.in...)).Pipe(path.Set(test.column, test.e))
|
||||
s := stream.New(rows.Emit([]string{"a"}, test.in...)).Pipe(path.Set(test.column, test.e))
|
||||
i := 0
|
||||
err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error {
|
||||
r, _ := out.GetRow()
|
||||
|
@@ -10,13 +10,14 @@ import (
|
||||
|
||||
type EmitOperator struct {
|
||||
stream.BaseOperator
|
||||
Rows []expr.Row
|
||||
Rows []expr.Row
|
||||
columns []string
|
||||
}
|
||||
|
||||
// Emit creates an operator that iterates over the given expressions.
|
||||
// Each expression must evaluate to an row.
|
||||
func Emit(rows ...expr.Row) *EmitOperator {
|
||||
return &EmitOperator{Rows: rows}
|
||||
func Emit(columns []string, rows ...expr.Row) *EmitOperator {
|
||||
return &EmitOperator{columns: columns, Rows: rows}
|
||||
}
|
||||
|
||||
func (op *EmitOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error {
|
||||
@@ -40,6 +41,10 @@ func (op *EmitOperator) Iterate(in *environment.Environment, fn func(out *enviro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *EmitOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
return it.columns, nil
|
||||
}
|
||||
|
||||
func (op *EmitOperator) Clone() stream.Operator {
|
||||
return &EmitOperator{
|
||||
BaseOperator: op.BaseOperator.Clone(),
|
||||
|
@@ -117,6 +117,19 @@ func (op *GroupAggregateOperator) Iterate(in *environment.Environment, f func(ou
|
||||
return f(e)
|
||||
}
|
||||
|
||||
func (op *GroupAggregateOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
columns := make([]string, 0, len(op.Builders)+1)
|
||||
if op.E != nil {
|
||||
columns = append(columns, op.E.String())
|
||||
}
|
||||
|
||||
for _, agg := range op.Builders {
|
||||
columns = append(columns, agg.String())
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (op *GroupAggregateOperator) String() string {
|
||||
var sb strings.Builder
|
||||
|
||||
|
@@ -35,6 +35,28 @@ func (op *ProjectOperator) Clone() stream.Operator {
|
||||
}
|
||||
}
|
||||
|
||||
func (op *ProjectOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
var cols, prev []string
|
||||
var err error
|
||||
|
||||
for _, e := range op.Exprs {
|
||||
if _, ok := e.(expr.Wildcard); ok {
|
||||
if prev == nil {
|
||||
prev, err = op.Prev.Columns(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cols = append(cols, prev...)
|
||||
} else {
|
||||
cols = append(cols, e.String())
|
||||
}
|
||||
}
|
||||
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
// Iterate implements the Operator interface.
|
||||
func (op *ProjectOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error {
|
||||
var mask RowMask
|
||||
|
@@ -26,6 +26,14 @@ func (s *Stream) Pipe(op Operator) *Stream {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Stream) Columns(env *environment.Environment) ([]string, error) {
|
||||
if s.Op == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.Op.Columns(env)
|
||||
}
|
||||
|
||||
func (s *Stream) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error {
|
||||
if s.Op == nil {
|
||||
return nil
|
||||
|
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
s := stream.New(rows.Emit(
|
||||
[]string{"a"},
|
||||
testutil.MakeRowExpr(t, `{"a": 1}`),
|
||||
testutil.MakeRowExpr(t, `{"a": 2}`),
|
||||
))
|
||||
@@ -80,13 +81,13 @@ func TestUnion(t *testing.T) {
|
||||
|
||||
var streams []*stream.Stream
|
||||
if test.first != nil {
|
||||
streams = append(streams, stream.New(rows.Emit(test.first...)))
|
||||
streams = append(streams, stream.New(rows.Emit([]string{"a", "b"}, test.first...)))
|
||||
}
|
||||
if test.second != nil {
|
||||
streams = append(streams, stream.New(rows.Emit(test.second...)))
|
||||
streams = append(streams, stream.New(rows.Emit([]string{"a", "b"}, test.second...)))
|
||||
}
|
||||
if test.third != nil {
|
||||
streams = append(streams, stream.New(rows.Emit(test.third...)))
|
||||
streams = append(streams, stream.New(rows.Emit([]string{"a", "b"}, test.third...)))
|
||||
}
|
||||
|
||||
st := stream.New(stream.Union(streams...))
|
||||
@@ -100,9 +101,9 @@ func TestUnion(t *testing.T) {
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
st := stream.New(stream.Union(
|
||||
stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 2}`)...)),
|
||||
stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 3}`, `{"a": 4}`)...)),
|
||||
stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 5}`, `{"a": 6}`)...)),
|
||||
stream.New(rows.Emit([]string{"a"}, testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 2}`)...)),
|
||||
stream.New(rows.Emit([]string{"a"}, testutil.MakeRowExprs(t, `{"a": 3}`, `{"a": 4}`)...)),
|
||||
stream.New(rows.Emit([]string{"a"}, testutil.MakeRowExprs(t, `{"a": 5}`, `{"a": 6}`)...)),
|
||||
))
|
||||
|
||||
require.Equal(t, `union(rows.Emit((1), (2)), rows.Emit((3), (4)), rows.Emit((5), (6)))`, st.String())
|
||||
@@ -113,8 +114,8 @@ func TestConcatOperator(t *testing.T) {
|
||||
in1 := testutil.MakeRowExprs(t, `{"a": 10}`, `{"a": 11}`)
|
||||
in2 := testutil.MakeRowExprs(t, `{"a": 12}`, `{"a": 13}`)
|
||||
|
||||
s1 := stream.New(rows.Emit(in1...))
|
||||
s2 := stream.New(rows.Emit(in2...))
|
||||
s1 := stream.New(rows.Emit([]string{"a"}, in1...))
|
||||
s2 := stream.New(rows.Emit([]string{"a"}, in2...))
|
||||
s := stream.Concat(s1, s2)
|
||||
|
||||
var got []row.Row
|
||||
|
@@ -86,6 +86,22 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *ScanOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
tx := env.GetTx()
|
||||
|
||||
info, err := tx.Catalog.GetTableInfo(it.TableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]string, len(info.ColumnConstraints.Ordered))
|
||||
for i, c := range info.ColumnConstraints.Ordered {
|
||||
columns[i] = c.Column
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (it *ScanOperator) String() string {
|
||||
var s strings.Builder
|
||||
|
||||
|
@@ -34,6 +34,14 @@ func (it *UnionOperator) Clone() Operator {
|
||||
}
|
||||
}
|
||||
|
||||
func (it *UnionOperator) Columns(env *environment.Environment) ([]string, error) {
|
||||
if len(it.Streams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return it.Streams[0].Columns(env)
|
||||
}
|
||||
|
||||
// Iterate iterates over all the streams and returns their union.
|
||||
func (it *UnionOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) (err error) {
|
||||
var temp *tree.Tree
|
||||
|
Reference in New Issue
Block a user