diff --git a/internal/database/catalogstore/store.go b/internal/database/catalogstore/store.go index d978b31b..9546fb59 100644 --- a/internal/database/catalogstore/store.go +++ b/internal/database/catalogstore/store.go @@ -81,36 +81,52 @@ func loadSequences(tx *database.Transaction, info []database.SequenceInfo) ([]da func loadCatalogStore(tx *database.Transaction, s *database.CatalogStore) (tables []database.TableInfo, indexes []database.IndexInfo, sequences []database.SequenceInfo, err error) { tb := s.Table(tx) - err = tb.IterateOnRange(nil, false, func(key *tree.Key, r database.Row) error { + it, err := tb.Iterator(nil) + if err != nil { + return nil, nil, nil, err + } + defer it.Close() + + // iterate over all the rows in the catalog store + // and load the tables and indexes + for it.First(); it.Valid(); it.Next() { + r, err := it.Value() + if err != nil { + return nil, nil, nil, err + } + tp, err := r.Get("type") if err != nil { - return err + return nil, nil, nil, err } switch types.AsString(tp) { case database.RelationTableType: ti, err := tableInfoFromRow(r) if err != nil { - return errors.Wrap(err, "failed to decode table info") + return nil, nil, nil, errors.Wrap(err, "failed to decode table info") } tables = append(tables, *ti) case database.RelationIndexType: i, err := indexInfoFromRow(r) if err != nil { - return errors.Wrap(err, "failed to decode index info") + return nil, nil, nil, errors.Wrap(err, "failed to decode index info") } indexes = append(indexes, *i) case database.RelationSequenceType: i, err := sequenceInfoFromRow(r) if err != nil { - return errors.Wrap(err, "failed to decode sequence info") + return nil, nil, nil, errors.Wrap(err, "failed to decode sequence info") } sequences = append(sequences, *i) } + } + + if err := it.Error(); err != nil { + return nil, nil, nil, err + } - return nil - }) return } diff --git a/internal/database/iteration.go b/internal/database/iteration.go index 946a6e60..5f1fc9f6 100644 --- a/internal/database/iteration.go +++ b/internal/database/iteration.go @@ -74,3 +74,33 @@ func (r *Range) IsEqual(other *Range) bool { return true } + +type Iterator struct { + *tree.Iterator + e EncodedRow + row BasicRow +} + +func newIterator(ti *tree.Iterator, tableName string, columnConstraints *ColumnConstraints) *Iterator { + it := Iterator{ + Iterator: ti, + } + + it.e.columnConstraints = columnConstraints + it.row.tableName = tableName + it.row.Row = &it.e + + return &it +} + +func (it *Iterator) Value() (Row, error) { + var err error + + it.row.key = it.Iterator.Key() + it.e.encoded, err = it.Iterator.Value() + if err != nil { + return nil, err + } + + return &it.row, nil +} diff --git a/internal/database/table.go b/internal/database/table.go index 11a17d5a..df64fac4 100644 --- a/internal/database/table.go +++ b/internal/database/table.go @@ -141,7 +141,7 @@ func (t *Table) Put(key *tree.Key, r row.Row) (Row, error) { }, err } -func (t *Table) IterateOnRange(rng *Range, reverse bool, fn func(key *tree.Key, r Row) error) error { +func (t *Table) Iterator(rng *Range) (*Iterator, error) { var columns []string pk := t.Info.PrimaryKey @@ -155,51 +155,16 @@ func (t *Table) IterateOnRange(rng *Range, reverse bool, fn func(key *tree.Key, if rng != nil { r, err = rng.ToTreeRange(&t.Info.ColumnConstraints, columns) if err != nil { - return err + return nil, err } } - e := EncodedRow{ - columnConstraints: &t.Info.ColumnConstraints, - } - row := BasicRow{ - tableName: t.Info.TableName, - Row: &e, - } - it, err := t.Tree.Iterator(r) if err != nil { - return err - } - defer it.Close() - - if reverse { - it.Last() - } else { - it.First() + return nil, err } - for it.Valid() { - k := it.Key() - enc, err := it.Value() - if err != nil { - return err - } - - row.key = k - e.encoded = enc - if err := fn(k, &row); err != nil { - return err - } - - if reverse { - it.Prev() - } else { - it.Next() - } - } - - return it.Error() + return newIterator(it, t.Info.TableName, &t.Info.ColumnConstraints), nil } // GetRow returns one row by key. diff --git a/internal/database/table_test.go b/internal/database/table_test.go index ae760baf..77e86fd1 100644 --- a/internal/database/table_test.go +++ b/internal/database/table_test.go @@ -318,11 +318,12 @@ func TestTableTruncate(t *testing.T) { err = tb.Truncate() require.NoError(t, err) - err = tb.IterateOnRange(nil, false, func(key *tree.Key, _ database.Row) error { - return errors.New("should not iterate") - }) - + it, err := tb.Iterator(nil) require.NoError(t, err) + defer it.Close() + + it.First() + require.False(t, it.Valid()) }) } @@ -372,9 +373,17 @@ func BenchmarkTableScan(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = tb.IterateOnRange(nil, false, func(*tree.Key, database.Row) error { - return nil - }) + it, err := tb.Iterator(nil) + require.NoError(b, err) + + for it.First(); it.Valid(); it.Next() { + } + + it.Close() + + // _ = tb.IterateOnRange(nil, false, func(*tree.Key, database.Row) error { + // return nil + // }) } b.StopTimer() }) diff --git a/internal/stream/table/scan.go b/internal/stream/table/scan.go index 006bb08d..ee17994d 100644 --- a/internal/stream/table/scan.go +++ b/internal/stream/table/scan.go @@ -7,7 +7,6 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/stream" - "github.com/chaisql/chai/internal/tree" "github.com/cockroachdb/errors" ) @@ -45,14 +44,14 @@ func (op *ScanOperator) Clone() stream.Operator { // Iterate over the objects of the table. Each object is stored in the environment // that is passed to the fn function, using SetCurrentValue. -func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { +func (op *ScanOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { var newEnv environment.Environment newEnv.SetOuter(in) - table := it.Table + table := op.Table var err error if table == nil { - table, err = in.GetTx().Catalog.GetTable(in.GetTx(), it.TableName) + table, err = in.GetTx().Catalog.GetTable(in.GetTx(), op.TableName) if err != nil { return err } @@ -60,24 +59,17 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro var ranges []*database.Range - if it.Ranges == nil { + if op.Ranges == nil { ranges = []*database.Range{nil} } else { - ranges, err = it.Ranges.Eval(in) + ranges, err = op.Ranges.Eval(in) if err != nil { return err } } for _, rng := range ranges { - err = table.IterateOnRange(rng, it.Reverse, func(key *tree.Key, r database.Row) error { - newEnv.SetRow(r) - - return fn(&newEnv) - }) - if errors.Is(err, stream.ErrStreamClosed) { - err = nil - } + err = op.iterateOverRange(table, rng, &newEnv, fn) if err != nil { return err } @@ -86,6 +78,43 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro return nil } +func (op *ScanOperator) iterateOverRange(table *database.Table, rng *database.Range, to *environment.Environment, fn func(out *environment.Environment) error) error { + it, err := table.Iterator(rng) + if err != nil { + return err + } + defer it.Close() + + if !op.Reverse { + it.First() + } else { + it.Last() + } + + for it.Valid() { + row, err := it.Value() + if err != nil { + return err + } + to.SetRow(row) + err = fn(to) + if errors.Is(err, stream.ErrStreamClosed) { + break + } + if err != nil { + return err + } + + if !op.Reverse { + it.Next() + } else { + it.Prev() + } + } + + return it.Error() +} + func (it *ScanOperator) Columns(env *environment.Environment) ([]string, error) { tx := env.GetTx()