diff --git a/internal/database/catalog.go b/internal/database/catalog.go index 7a8a846a..5de78287 100644 --- a/internal/database/catalog.go +++ b/internal/database/catalog.go @@ -215,7 +215,7 @@ func (c *CatalogWriter) ensureSequenceExists(tx *Transaction, seq *SequenceInfo) return nil } -func (c *CatalogWriter) generateStoreName(tx *Transaction) (tree.Namespace, error) { +func (c *CatalogWriter) generateStoreNamespace(tx *Transaction) (tree.Namespace, error) { seq, err := c.Catalog.GetSequence(StoreSequence) if err != nil { return 0, err @@ -249,7 +249,7 @@ func (c *CatalogWriter) CreateTable(tx *Transaction, tableName string, info *Tab } if info.StoreNamespace == 0 { - info.StoreNamespace, err = c.generateStoreName(tx) + info.StoreNamespace, err = c.generateStoreNamespace(tx) if err != nil { return err } @@ -302,33 +302,38 @@ func (c *CatalogWriter) DropTable(tx *Transaction, tableName string) error { // CreateIndex creates an index with the given name. // If it already exists, returns errs.ErrIndexAlreadyExists. -func (c *CatalogWriter) CreateIndex(tx *Transaction, info *IndexInfo) error { +func (c *CatalogWriter) CreateIndex(tx *Transaction, info *IndexInfo) (*IndexInfo, error) { // check if the associated table exists ti, err := c.Catalog.GetTableInfo(info.Owner.TableName) if err != nil { - return err + return nil, err } // check if the indexed fields exist for _, p := range info.Paths { fc := ti.GetFieldConstraintForPath(p) if fc == nil { - return errors.Errorf("field %q does not exist for table %q", p, ti.TableName) + return nil, errors.Errorf("field %q does not exist for table %q", p, ti.TableName) } } - info.StoreNamespace, err = c.generateStoreName(tx) + info.StoreNamespace, err = c.generateStoreNamespace(tx) if err != nil { - return err + return nil, err } rel := IndexInfoRelation{Info: info} err = c.Catalog.Cache.Add(tx, &rel) if err != nil { - return err + return nil, err } - return c.Catalog.CatalogTable.Insert(tx, &rel) + err = c.Catalog.CatalogTable.Insert(tx, &rel) + if err != nil { + return nil, err + } + + return info, nil } // DropIndex deletes an index from the diff --git a/internal/database/catalog_test.go b/internal/database/catalog_test.go index 8dfac829..3982efca 100644 --- a/internal/database/catalog_test.go +++ b/internal/database/catalog_test.go @@ -113,9 +113,9 @@ func TestCatalogTable(t *testing.T) { err := catalog.CreateTable(tx, "foo", ti) assert.NoError(t, err) - err = catalog.CreateIndex(tx, &database.IndexInfo{Paths: []document.Path{testutil.ParseDocumentPath(t, "gender")}, IndexName: "idx_gender", Owner: database.Owner{TableName: "foo"}}) + _, err = catalog.CreateIndex(tx, &database.IndexInfo{Paths: []document.Path{testutil.ParseDocumentPath(t, "gender")}, IndexName: "idx_gender", Owner: database.Owner{TableName: "foo"}}) assert.NoError(t, err) - err = catalog.CreateIndex(tx, &database.IndexInfo{Paths: []document.Path{testutil.ParseDocumentPath(t, "city")}, IndexName: "idx_city", Owner: database.Owner{TableName: "foo"}, Unique: true}) + _, err = catalog.CreateIndex(tx, &database.IndexInfo{Paths: []document.Path{testutil.ParseDocumentPath(t, "city")}, IndexName: "idx_city", Owner: database.Owner{TableName: "foo"}, Unique: true}) assert.NoError(t, err) seq := database.SequenceInfo{ @@ -295,7 +295,7 @@ func TestCatalogCreateIndex(t *testing.T) { clone := db.Catalog().Clone() updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { - err := catalog.CreateIndex(tx, &database.IndexInfo{ + _, err := catalog.CreateIndex(tx, &database.IndexInfo{ IndexName: "idx_a", Owner: database.Owner{TableName: "test"}, Paths: []document.Path{testutil.ParseDocumentPath(t, "a")}, }) assert.NoError(t, err) @@ -321,12 +321,12 @@ func TestCatalogCreateIndex(t *testing.T) { }) updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { - err := catalog.CreateIndex(tx, &database.IndexInfo{ + _, err := catalog.CreateIndex(tx, &database.IndexInfo{ IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Paths: []document.Path{testutil.ParseDocumentPath(t, "foo")}, }) assert.NoError(t, err) - err = catalog.CreateIndex(tx, &database.IndexInfo{ + _, err = catalog.CreateIndex(tx, &database.IndexInfo{ IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Paths: []document.Path{testutil.ParseDocumentPath(t, "foo")}, }) assert.ErrorIs(t, err, errs.AlreadyExistsError{Name: "idxFoo"}) @@ -337,7 +337,7 @@ func TestCatalogCreateIndex(t *testing.T) { t.Run("Should fail if table doesn't exist", func(t *testing.T) { db := testutil.NewTestDB(t) updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { - err := catalog.CreateIndex(tx, &database.IndexInfo{ + _, err := catalog.CreateIndex(tx, &database.IndexInfo{ IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Paths: []document.Path{testutil.ParseDocumentPath(t, "foo")}, }) if !errs.IsNotFoundError(err) { @@ -368,7 +368,7 @@ func TestCatalogCreateIndex(t *testing.T) { }) updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { - err := catalog.CreateIndex(tx, &database.IndexInfo{ + _, err := catalog.CreateIndex(tx, &database.IndexInfo{ Owner: database.Owner{TableName: "test"}, Paths: []document.Path{testutil.ParseDocumentPath(t, "foo.` bar `.c")}, }) assert.NoError(t, err) @@ -377,7 +377,7 @@ func TestCatalogCreateIndex(t *testing.T) { assert.NoError(t, err) // create another one - err = catalog.CreateIndex(tx, &database.IndexInfo{ + _, err = catalog.CreateIndex(tx, &database.IndexInfo{ Owner: database.Owner{TableName: "test"}, Paths: []document.Path{testutil.ParseDocumentPath(t, "foo.` bar `.c")}, }) assert.NoError(t, err) @@ -401,11 +401,11 @@ func TestTxDropIndex(t *testing.T) { ), }) assert.NoError(t, err) - err = catalog.CreateIndex(tx, &database.IndexInfo{ + _, err = catalog.CreateIndex(tx, &database.IndexInfo{ IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Paths: []document.Path{testutil.ParseDocumentPath(t, "foo")}, }) assert.NoError(t, err) - err = catalog.CreateIndex(tx, &database.IndexInfo{ + _, err = catalog.CreateIndex(tx, &database.IndexInfo{ IndexName: "idxBar", Owner: database.Owner{TableName: "test"}, Paths: []document.Path{testutil.ParseDocumentPath(t, "bar")}, }) assert.NoError(t, err) diff --git a/internal/database/constraint.go b/internal/database/constraint.go index 1b197027..7d325aa3 100644 --- a/internal/database/constraint.go +++ b/internal/database/constraint.go @@ -22,7 +22,7 @@ type FieldConstraint struct { } func (f *FieldConstraint) IsEmpty() bool { - return f.Type.IsAny() && !f.IsNotNull && f.DefaultValue == nil + return f.Field == "" && f.Type.IsAny() && !f.IsNotNull && f.DefaultValue == nil } func (f *FieldConstraint) String() string { @@ -43,7 +43,7 @@ func (f *FieldConstraint) String() string { s.WriteString(" NOT NULL") } - if f.HasDefaultValue() { + if f.DefaultValue != nil { s.WriteString(" DEFAULT ") s.WriteString(f.DefaultValue.String()) } @@ -51,11 +51,6 @@ func (f *FieldConstraint) String() string { return s.String() } -// HasDefaultValue returns this field contains a default value constraint. -func (f *FieldConstraint) HasDefaultValue() bool { - return f.DefaultValue != nil -} - // FieldConstraints is a list of field constraints. type FieldConstraints struct { Ordered []*FieldConstraint diff --git a/internal/database/table.go b/internal/database/table.go index aab5f202..22d893bc 100644 --- a/internal/database/table.go +++ b/internal/database/table.go @@ -65,7 +65,8 @@ func (t *Table) Insert(d types.Document) (*tree.Key, types.Document, error) { func (t *Table) encodeDocument(d types.Document) (types.Document, []byte, error) { ed, ok := d.(*EncodedDocument) - if ok { + // pointer comparison is enough here + if ok && ed.fieldConstraints == &t.Info.FieldConstraints { return d, ed.encoded, nil } diff --git a/internal/query/statement/alter.go b/internal/query/statement/alter.go index 2e0e5f92..86ad980c 100644 --- a/internal/query/statement/alter.go +++ b/internal/query/statement/alter.go @@ -4,22 +4,25 @@ import ( "github.com/cockroachdb/errors" "github.com/genjidb/genji/internal/database" errs "github.com/genjidb/genji/internal/errors" + "github.com/genjidb/genji/internal/stream" + "github.com/genjidb/genji/internal/stream/index" + "github.com/genjidb/genji/internal/stream/table" ) -// AlterStmt is a DSL that allows creating a full ALTER TABLE query. -type AlterStmt struct { +// AlterTableRenameStmt is a DSL that allows creating a full ALTER TABLE query. +type AlterTableRenameStmt struct { TableName string NewTableName string } // IsReadOnly always returns false. It implements the Statement interface. -func (stmt AlterStmt) IsReadOnly() bool { +func (stmt AlterTableRenameStmt) IsReadOnly() bool { return false } // Run runs the ALTER TABLE statement in the given transaction. // It implements the Statement interface. -func (stmt AlterStmt) Run(ctx *Context) (Result, error) { +func (stmt AlterTableRenameStmt) Run(ctx *Context) (Result, error) { var res Result if stmt.TableName == "" { @@ -38,25 +41,131 @@ func (stmt AlterStmt) Run(ctx *Context) (Result, error) { return res, err } -type AlterTableAddField struct { - Info database.TableInfo +type AlterTableAddFieldStmt struct { + TableName string + FieldConstraint *database.FieldConstraint + TableConstraints database.TableConstraints } // IsReadOnly always returns false. It implements the Statement interface. -func (stmt AlterTableAddField) IsReadOnly() bool { +func (stmt *AlterTableAddFieldStmt) IsReadOnly() bool { return false } // Run runs the ALTER TABLE ADD FIELD statement in the given transaction. // It implements the Statement interface. -func (stmt AlterTableAddField) Run(ctx *Context) (Result, error) { - var res Result +// The statement rebuilds the table. +func (stmt *AlterTableAddFieldStmt) Run(ctx *Context) (Result, error) { + var err error - var fc *database.FieldConstraint - if len(stmt.Info.FieldConstraints.Ordered) != 0 { - fc = stmt.Info.FieldConstraints.Ordered[0] + // get the table before adding the field constraint + // and assign the table to the table.Scan operator + // so that it can decode the records properly + scan := table.Scan(stmt.TableName) + scan.Table, err = ctx.Tx.Catalog.GetTable(ctx.Tx, stmt.TableName) + if err != nil { + return Result{}, errors.Wrap(err, "failed to get table") } - err := ctx.Tx.CatalogWriter().AddFieldConstraint(ctx.Tx, stmt.Info.TableName, fc, stmt.Info.TableConstraints) - return res, err + // get the current list of indexes + indexNames := ctx.Tx.Catalog.ListIndexes(stmt.TableName) + + // add the field constraint to the table + err = ctx.Tx.CatalogWriter().AddFieldConstraint( + ctx.Tx, + stmt.TableName, + stmt.FieldConstraint, + stmt.TableConstraints) + if err != nil { + return Result{}, err + } + + // create a unique index for every unique constraint + pkAdded := false + var newIdxs []*database.IndexInfo + for _, tc := range stmt.TableConstraints { + if tc.Unique { + idx, err := ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{ + Paths: tc.Paths, + Unique: true, + Owner: database.Owner{ + TableName: stmt.TableName, + Paths: tc.Paths, + }, + }) + if err != nil { + return Result{}, err + } + + newIdxs = append(newIdxs, idx) + } + + if tc.PrimaryKey { + pkAdded = true + } + } + + // create the stream: + // on one side, scan the table with the old schema + // on the other side, insert the records into the same table with the new schema + s := stream.New(scan) + + // if a primary key was added, we need to delete the old records + // and old indexes, and insert the new records and indexes + if pkAdded { + // delete the old records from the indexes + for _, indexName := range indexNames { + s = s.Pipe(index.Delete(indexName)) + } + // delete the old records from the table + s = s.Pipe(table.Delete(stmt.TableName)) + + // validate the record against the new schema + s = s.Pipe(table.Validate(stmt.TableName)) + + // insert the record with the new primary key + s = s.Pipe(table.Insert(stmt.TableName)) + + // insert the record into the all the indexes + indexNames = ctx.Tx.Catalog.ListIndexes(stmt.TableName) + for _, indexName := range indexNames { + info, err := ctx.Tx.Catalog.GetIndexInfo(indexName) + if err != nil { + return Result{}, err + } + if info.Unique { + s = s.Pipe(index.Validate(indexName)) + } + + s = s.Pipe(index.Insert(indexName)) + } + } else { + // otherwise, we can just replace the old records with the new ones + + // validate the record against the new schema + s = s.Pipe(table.Validate(stmt.TableName)) + + // replace the old record with the new one + s = s.Pipe(table.Replace(stmt.TableName)) + + // update the new indexes only + for _, idx := range newIdxs { + if idx.Unique { + s = s.Pipe(index.Validate(idx.IndexName)) + } + + s = s.Pipe(index.Insert(idx.IndexName)) + } + } + + // ALTER TABLE ADD FIELD does not return any result + s = s.Pipe(stream.Discard()) + + // do NOT optimize the stream + return Result{ + Iterator: &StreamStmtIterator{ + Stream: s, + Context: ctx, + }, + }, nil } diff --git a/internal/query/statement/create.go b/internal/query/statement/create.go index 0e89e9f2..61e72f78 100644 --- a/internal/query/statement/create.go +++ b/internal/query/statement/create.go @@ -55,7 +55,7 @@ func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) { // create a unique index for every unique constraint for _, tc := range stmt.Info.TableConstraints { if tc.Unique { - err = ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{ + _, err = ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{ Paths: tc.Paths, Unique: true, Owner: database.Owner{ @@ -88,7 +88,7 @@ func (stmt *CreateIndexStmt) IsReadOnly() bool { func (stmt *CreateIndexStmt) Run(ctx *Context) (Result, error) { var res Result - err := ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &stmt.Info) + _, err := ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &stmt.Info) if stmt.IfNotExists { if errs.IsAlreadyExistsError(err) { return res, nil @@ -99,7 +99,7 @@ func (stmt *CreateIndexStmt) Run(ctx *Context) (Result, error) { } s := stream.New(table.Scan(stmt.Info.Owner.TableName)). - Pipe(index.IndexInsert(stmt.Info.IndexName)). + Pipe(index.Insert(stmt.Info.IndexName)). Pipe(stream.Discard()) ss := PreparedStreamStmt{ diff --git a/internal/query/statement/explain_test.go b/internal/query/statement/explain_test.go index 75d1aa2b..58c16d61 100644 --- a/internal/query/statement/explain_test.go +++ b/internal/query/statement/explain_test.go @@ -32,9 +32,9 @@ func TestExplainStmt(t *testing.T) { {"EXPLAIN SELECT a + 1 FROM test WHERE c > 30 ORDER BY a DESC LIMIT 10 OFFSET 20", false, `"index.ScanReverse(\"idx_a\") | docs.Filter(c > 30) | docs.Project(a + 1) | docs.Skip(20) | docs.Take(10)"`}, {"EXPLAIN SELECT a FROM test WHERE c > 30 GROUP BY a ORDER BY a DESC LIMIT 10 OFFSET 20", false, `"index.ScanReverse(\"idx_a\") | docs.Filter(c > 30) | docs.GroupAggregate(a) | docs.Project(a) | docs.Skip(20) | docs.Take(10)"`}, {"EXPLAIN SELECT a + 1 FROM test WHERE c > 30 GROUP BY a + 1 ORDER BY a DESC LIMIT 10 OFFSET 20", false, `"table.Scan(\"test\") | docs.Filter(c > 30) | docs.TempTreeSort(a + 1) | docs.GroupAggregate(a + 1) | docs.Project(a + 1) | docs.TempTreeSortReverse(a) | docs.Skip(20) | docs.Take(10)"`}, - {"EXPLAIN UPDATE test SET a = 10", false, `"table.Scan(\"test\") | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, - {"EXPLAIN UPDATE test SET a = 10 WHERE c > 10", false, `"table.Scan(\"test\") | docs.Filter(c > 10) | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, - {"EXPLAIN UPDATE test SET a = 10 WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": [10], \"exclusive\": true}]) | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, + {"EXPLAIN UPDATE test SET a = 10", false, `"table.Scan(\"test\") | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Validate(\"idx_b\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, + {"EXPLAIN UPDATE test SET a = 10 WHERE c > 10", false, `"table.Scan(\"test\") | docs.Filter(c > 10) | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Validate(\"idx_b\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, + {"EXPLAIN UPDATE test SET a = 10 WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": [10], \"exclusive\": true}]) | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Validate(\"idx_b\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, {"EXPLAIN DELETE FROM test", false, `"table.Scan(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Delete('test') | discard()"`}, {"EXPLAIN DELETE FROM test WHERE c > 10", false, `"table.Scan(\"test\") | docs.Filter(c > 10) | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Delete('test') | discard()"`}, {"EXPLAIN DELETE FROM test WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": [10], \"exclusive\": true}]) | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Delete('test') | discard()"`}, diff --git a/internal/query/statement/insert.go b/internal/query/statement/insert.go index 569def6a..bc69d4ff 100644 --- a/internal/query/statement/insert.go +++ b/internal/query/statement/insert.go @@ -121,7 +121,7 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) { s = s.Pipe(table.Insert(stmt.TableName)) for _, indexName := range indexNames { - s = s.Pipe(index.IndexInsert(indexName)) + s = s.Pipe(index.Insert(indexName)) } if len(stmt.Returning) > 0 { diff --git a/internal/query/statement/reindex.go b/internal/query/statement/reindex.go index 89d01919..19057424 100644 --- a/internal/query/statement/reindex.go +++ b/internal/query/statement/reindex.go @@ -57,7 +57,7 @@ func (stmt ReIndexStmt) Prepare(ctx *Context) (Statement, error) { return nil, err } - s := stream.New(table.Scan(info.Owner.TableName)).Pipe(index.IndexInsert(info.IndexName)) + s := stream.New(table.Scan(info.Owner.TableName)).Pipe(index.Insert(info.IndexName)) streams = append(streams, s) } diff --git a/internal/query/statement/update.go b/internal/query/statement/update.go index c0f82228..e152f2f6 100644 --- a/internal/query/statement/update.go +++ b/internal/query/statement/update.go @@ -108,7 +108,15 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) { } for _, indexName := range indexNames { - s = s.Pipe(index.IndexInsert(indexName)) + info, err := c.Tx.Catalog.GetIndexInfo(indexName) + if err != nil { + return nil, err + } + if info.Unique { + s = s.Pipe(index.Validate(indexName)) + } + + s = s.Pipe(index.Insert(indexName)) } s = s.Pipe(stream.Discard()) diff --git a/internal/sql/parser/alter.go b/internal/sql/parser/alter.go index 311f1b42..62598be8 100644 --- a/internal/sql/parser/alter.go +++ b/internal/sql/parser/alter.go @@ -7,8 +7,8 @@ import ( "github.com/genjidb/genji/internal/sql/scanner" ) -func (p *Parser) parseAlterTableRenameStatement(tableName string) (_ statement.AlterStmt, err error) { - var stmt statement.AlterStmt +func (p *Parser) parseAlterTableRenameStatement(tableName string) (_ statement.AlterTableRenameStmt, err error) { + var stmt statement.AlterTableRenameStmt stmt.TableName = tableName // Parse "TO". @@ -25,42 +25,27 @@ func (p *Parser) parseAlterTableRenameStatement(tableName string) (_ statement.A return stmt, nil } -func (p *Parser) parseAlterTableAddFieldStatement(tableName string) (_ statement.AlterTableAddField, err error) { - var stmt statement.AlterTableAddField - stmt.Info.TableName = tableName +func (p *Parser) parseAlterTableAddFieldStatement(tableName string) (*statement.AlterTableAddFieldStmt, error) { + var stmt statement.AlterTableAddFieldStmt + stmt.TableName = tableName // Parse "FIELD". if err := p.parseTokens(scanner.FIELD); err != nil { - return stmt, err + return nil, err } // Parse new field definition. - fc, tcs, err := p.parseFieldDefinition(nil) + var err error + stmt.FieldConstraint, stmt.TableConstraints, err = p.parseFieldDefinition(nil) if err != nil { - return stmt, err + return nil, err } - if fc.IsEmpty() { - return stmt, &ParseError{Message: "cannot add a field with no constraint"} + if stmt.FieldConstraint.IsEmpty() { + return nil, &ParseError{Message: "cannot add a field with no constraint"} } - err = stmt.Info.AddFieldConstraint(fc) - if err != nil { - return stmt, err - } - - for _, tc := range tcs { - err = stmt.Info.AddTableConstraint(tc) - if err != nil { - return stmt, err - } - } - - if stmt.Info.GetPrimaryKey() != nil { - return stmt, &ParseError{Message: "cannot add a PRIMARY KEY constraint"} - } - - return stmt, nil + return &stmt, nil } // parseAlterStatement parses a Alter query string and returns a Statement AST object. diff --git a/internal/sql/parser/alter_test.go b/internal/sql/parser/alter_test.go index 24ce1da7..1e1c0aee 100644 --- a/internal/sql/parser/alter_test.go +++ b/internal/sql/parser/alter_test.go @@ -3,6 +3,7 @@ package parser_test import ( "testing" + "github.com/genjidb/genji/document" "github.com/genjidb/genji/internal/database" "github.com/genjidb/genji/internal/expr" "github.com/genjidb/genji/internal/query/statement" @@ -19,10 +20,10 @@ func TestParserAlterTable(t *testing.T) { expected statement.Statement errored bool }{ - {"Basic", "ALTER TABLE foo RENAME TO bar", statement.AlterStmt{TableName: "foo", NewTableName: "bar"}, false}, - {"With error / missing TABLE keyword", "ALTER foo RENAME TO bar", statement.AlterStmt{}, true}, - {"With error / two identifiers for table name", "ALTER TABLE foo baz RENAME TO bar", statement.AlterStmt{}, true}, - {"With error / two identifiers for new table name", "ALTER TABLE foo RENAME TO bar baz", statement.AlterStmt{}, true}, + {"Basic", "ALTER TABLE foo RENAME TO bar", statement.AlterTableRenameStmt{TableName: "foo", NewTableName: "bar"}, false}, + {"With error / missing TABLE keyword", "ALTER foo RENAME TO bar", statement.AlterTableRenameStmt{}, true}, + {"With error / two identifiers for table name", "ALTER TABLE foo baz RENAME TO bar", statement.AlterTableRenameStmt{}, true}, + {"With error / two identifiers for new table name", "ALTER TABLE foo RENAME TO bar baz", statement.AlterTableRenameStmt{}, true}, } for _, test := range tests { @@ -46,41 +47,47 @@ func TestParserAlterTableAddField(t *testing.T) { expected statement.Statement errored bool }{ - {"Basic", "ALTER TABLE foo ADD FIELD bar", nil, true}, - {"With type", "ALTER TABLE foo ADD FIELD bar integer", statement.AlterTableAddField{ - Info: database.TableInfo{ - TableName: "foo", - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{ - Field: "bar", - Type: types.IntegerValue, - }, - ), + {"Basic", "ALTER TABLE foo ADD FIELD bar", &statement.AlterTableAddFieldStmt{ + TableName: "foo", + FieldConstraint: &database.FieldConstraint{ + Field: "bar", + Type: types.AnyValue, }, }, false}, - {"With not null", "ALTER TABLE foo ADD FIELD bar NOT NULL", statement.AlterTableAddField{ - Info: database.TableInfo{ - TableName: "foo", - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{ - Field: "bar", - IsNotNull: true, - }, - ), + {"With type", "ALTER TABLE foo ADD FIELD bar integer", &statement.AlterTableAddFieldStmt{ + TableName: "foo", + FieldConstraint: &database.FieldConstraint{ + Field: "bar", + Type: types.IntegerValue, }, }, false}, - {"With primary key", "ALTER TABLE foo ADD FIELD bar PRIMARY KEY", nil, true}, - {"With multiple constraints", "ALTER TABLE foo ADD FIELD bar integer NOT NULL DEFAULT 0", statement.AlterTableAddField{ - Info: database.TableInfo{ - TableName: "foo", - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{ - Field: "bar", - Type: types.IntegerValue, - IsNotNull: true, - DefaultValue: expr.Constraint(expr.LiteralValue{Value: types.NewIntegerValue(0)}), - }, - ), + {"With not null", "ALTER TABLE foo ADD FIELD bar NOT NULL", &statement.AlterTableAddFieldStmt{ + TableName: "foo", + FieldConstraint: &database.FieldConstraint{ + Field: "bar", + IsNotNull: true, + }, + }, false}, + {"With primary key", "ALTER TABLE foo ADD FIELD bar PRIMARY KEY", &statement.AlterTableAddFieldStmt{ + TableName: "foo", + FieldConstraint: &database.FieldConstraint{ + Field: "bar", + Type: types.AnyValue, + }, + TableConstraints: database.TableConstraints{ + &database.TableConstraint{ + Paths: document.Paths{document.NewPath("bar")}, + PrimaryKey: true, + }, + }, + }, false}, + {"With multiple constraints", "ALTER TABLE foo ADD FIELD bar integer NOT NULL DEFAULT 0", &statement.AlterTableAddFieldStmt{ + TableName: "foo", + FieldConstraint: &database.FieldConstraint{ + Field: "bar", + Type: types.IntegerValue, + IsNotNull: true, + DefaultValue: expr.Constraint(expr.LiteralValue{Value: types.NewIntegerValue(0)}), }, }, false}, {"With error / missing FIELD keyword", "ALTER TABLE foo ADD bar", nil, true}, diff --git a/internal/sql/parser/create.go b/internal/sql/parser/create.go index b5bbe0b7..63b35b12 100644 --- a/internal/sql/parser/create.go +++ b/internal/sql/parser/create.go @@ -216,7 +216,7 @@ LOOP: fc.IsNotNull = true case scanner.DEFAULT: // if it has already a default value we return an error - if fc.HasDefaultValue() { + if fc.DefaultValue != nil { return nil, nil, newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", ")"}, pos) } diff --git a/internal/stream/index/index.go b/internal/stream/index/index.go deleted file mode 100644 index 60a0228e..00000000 --- a/internal/stream/index/index.go +++ /dev/null @@ -1 +0,0 @@ -package index diff --git a/internal/stream/index/insert.go b/internal/stream/index/insert.go index e0b2536d..623d30a7 100644 --- a/internal/stream/index/insert.go +++ b/internal/stream/index/insert.go @@ -16,7 +16,7 @@ type InsertOperator struct { indexName string } -func IndexInsert(indexName string) *InsertOperator { +func Insert(indexName string) *InsertOperator { return &InsertOperator{ indexName: indexName, } diff --git a/internal/stream/table/scan.go b/internal/stream/table/scan.go index 6c02657c..13d22650 100644 --- a/internal/stream/table/scan.go +++ b/internal/stream/table/scan.go @@ -18,6 +18,9 @@ type ScanOperator struct { TableName string Ranges stream.Ranges Reverse bool + // If set, the operator will scan this table. + // It not set, it will get the scan from the catalog. + Table *database.Table } // Scan creates an iterator that iterates over each document of the given table that match the given ranges. @@ -31,33 +34,6 @@ func ScanReverse(tableName string, ranges ...stream.Range) *ScanOperator { return &ScanOperator{TableName: tableName, Ranges: ranges, Reverse: true} } -func (it *ScanOperator) String() string { - var s strings.Builder - - s.WriteString("table.Scan") - if it.Reverse { - s.WriteString("Reverse") - } - - s.WriteRune('(') - - s.WriteString(strconv.Quote(it.TableName)) - if len(it.Ranges) > 0 { - s.WriteString(", [") - for i, r := range it.Ranges { - s.WriteString(r.String()) - if i+1 < len(it.Ranges) { - s.WriteString(", ") - } - } - s.WriteString("]") - } - - s.WriteString(")") - - return s.String() -} - // Iterate over the documents of the table. Each document is stored in the environment // that is passed to the fn function, using SetCurrentValue. func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { @@ -65,9 +41,13 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro newEnv.SetOuter(in) newEnv.Set(environment.TableKey, types.NewTextValue(it.TableName)) - table, err := in.GetTx().Catalog.GetTable(in.GetTx(), it.TableName) - if err != nil { - return err + table := it.Table + var err error + if table == nil { + table, err = in.GetTx().Catalog.GetTable(in.GetTx(), it.TableName) + if err != nil { + return err + } } var ranges []*database.Range @@ -98,3 +78,30 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro return nil } + +func (it *ScanOperator) String() string { + var s strings.Builder + + s.WriteString("table.Scan") + if it.Reverse { + s.WriteString("Reverse") + } + + s.WriteRune('(') + + s.WriteString(strconv.Quote(it.TableName)) + if len(it.Ranges) > 0 { + s.WriteString(", [") + for i, r := range it.Ranges { + s.WriteString(r.String()) + if i+1 < len(it.Ranges) { + s.WriteString(", ") + } + } + s.WriteString("]") + } + + s.WriteString(")") + + return s.String() +} diff --git a/internal/testutil/assert/errors.go b/internal/testutil/assert/errors.go index 9fb157b5..0358a297 100644 --- a/internal/testutil/assert/errors.go +++ b/internal/testutil/assert/errors.go @@ -53,5 +53,5 @@ func NoErrorf(t testing.TB, err error, str string, args ...interface{}) { func NoError(t testing.TB, err error) { t.Helper() - NoErrorf(t, err, "Expected error to be nil but got %q instead: %+v", err, err) + NoErrorf(t, err, "Expected error to be nil: %+v", err) } diff --git a/sqltests/ALTER_TABLE/add_field.sql b/sqltests/ALTER_TABLE/add_field.sql new file mode 100644 index 00000000..37f01388 --- /dev/null +++ b/sqltests/ALTER_TABLE/add_field.sql @@ -0,0 +1,144 @@ +-- setup: +CREATE TABLE test(a int); + +-- test: field constraints are updated +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b int DEFAULT 0; +SELECT name, sql FROM __genji_catalog WHERE type = "table" AND name = "test"; +/* result: +{ + "name": "test", + "sql": "CREATE TABLE test (a INTEGER, b INTEGER DEFAULT 0)" +} +*/ + +-- test: default value is updated +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b int DEFAULT 0; +SELECT * FROM test; +/* result: +{ + "a": 1, + "b": 0 +} +{ + "a": 2, + "b": 0 +} +*/ + +-- test: not null alone +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b int NOT NULL; +-- error: NOT NULL constraint error: [b] + +-- test: not null with default +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b int NOT NULL DEFAULT 10; +SELECT * FROM test; +/* result: +{ + "a": 1, + "b": 10 +} +{ + "a": 2, + "b": 10 +} +*/ + +-- test: unique +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b int UNIQUE; +SELECT * FROM test; +/* result: +{ + "a": 1 +} +{ + "a": 2 +} +*/ + +-- test: unique with default: with data +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b int UNIQUE DEFAULT 10; +-- error: UNIQUE constraint error: [b] + +-- test: unique with default: without data +ALTER TABLE test ADD FIELD b int UNIQUE DEFAULT 10; +INSERT INTO test VALUES (1), (2); +-- error: UNIQUE constraint error: [b] + +-- test: primary key: with data +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b int PRIMARY KEY; +-- error: NOT NULL constraint error: [b] + +-- test: primary key: without data +ALTER TABLE test ADD FIELD b int PRIMARY KEY; +INSERT INTO test VALUES (1, 10), (2, 20); +SELECT pk() FROM test; +/* result: +{ + "pk()": [10] +} +{ + "pk()": [20] +} +*/ + +-- test: primary key: with default: with data +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b int PRIMARY KEY DEFAULT 10; +-- error: PRIMARY KEY constraint error: [b] + +-- test: primary key: with default: without data +ALTER TABLE test ADD FIELD b int PRIMARY KEY DEFAULT 10; +INSERT INTO test VALUES (2, 20), (3, 30); +INSERT INTO test (a) VALUES (1); +SELECT * FROM test; +/* result: +{ + "a": 1, + "b": 10 +} +{ + "a": 2, + "b": 20 +} +{ + "a": 3, + "b": 30 +} +*/ + +-- test: no type +INSERT INTO test VALUES (1), (2); +ALTER TABLE test ADD FIELD b; +INSERT INTO test VALUES (3, 30), (4, 'hello'); +SELECT * FROM test; +/* result: +{ + "a": 1 +} +{ + "a": 2 +} +{ + "a": 3, + "b": 30.0 +} +{ + "a": 4, + "b": "hello" +} +*/ + +-- test: bad syntax: no field name +ALTER TABLE test ADD FIELD; +-- error: + +-- test: bad syntax: missing FIELD keyword +ALTER TABLE test ADD a int; +-- error: \ No newline at end of file diff --git a/sqltests/ALTER_TABLE/base.sql b/sqltests/ALTER_TABLE/base.sql deleted file mode 100644 index 47a87406..00000000 --- a/sqltests/ALTER_TABLE/base.sql +++ /dev/null @@ -1,21 +0,0 @@ --- setup: -CREATE TABLE test(a int primary key); - --- test: rename -ALTER TABLE test RENAME TO test2; -SELECT name, sql FROM __genji_catalog WHERE type = "table" AND (name = "test2" OR name = "test"); -/* result: -{ - "name": "test2", - "sql": "CREATE TABLE test2 (a INTEGER NOT NULL, CONSTRAINT test_pk PRIMARY KEY (a))" -} -*/ - --- test: non-existing -ALTER TABLE unknown RENAME TO test2; --- error: - --- test: duplicate -CREATE TABLE test2; -ALTER TABLE test2 RENAME TO test; --- error: diff --git a/sqltests/ALTER_TABLE/rename.sql b/sqltests/ALTER_TABLE/rename.sql new file mode 100644 index 00000000..4ed61115 --- /dev/null +++ b/sqltests/ALTER_TABLE/rename.sql @@ -0,0 +1,48 @@ +-- setup: +CREATE TABLE test(a int primary key); + +-- test: rename +ALTER TABLE test RENAME TO test2; +SELECT name, sql FROM __genji_catalog WHERE type = "table" AND (name = "test2" OR name = "test"); +/* result: +{ + "name": "test2", + "sql": "CREATE TABLE test2 (a INTEGER NOT NULL, CONSTRAINT test_pk PRIMARY KEY (a))" +} +*/ + +-- test: non-existing +ALTER TABLE unknown RENAME TO test2; +-- error: + +-- test: duplicate +CREATE TABLE test2; +ALTER TABLE test2 RENAME TO test; +-- error: + +-- test: reserved name +ALTER TABLE test RENAME TO __genji_catalog; +-- error: + +-- test: bad syntax: no new name +ALTER TABLE test RENAME TO; +-- error: + +-- test: bad syntax: no table name +ALTER TABLE RENAME TO test2; +-- error: + +-- test: bad syntax: no TABLE +ALTER RENAME TABLE test TO test2; +-- error: + +-- test: bad syntax: two identifiers for new name +ALTER TABLE test RENAME TO test2 test3; +-- error: + +-- test: bad syntax: two identifiers for table name +ALTER TABLE test test2 RENAME TO test3; +-- error: + + + diff --git a/sqltests/UPDATE/unique.sql b/sqltests/UPDATE/unique.sql new file mode 100644 index 00000000..b29715bf --- /dev/null +++ b/sqltests/UPDATE/unique.sql @@ -0,0 +1,7 @@ +-- setup: +CREATE TABLE test(a int UNIQUE); + +-- test: conflict +INSERT INTO test VALUES (1), (2); +UPDATE test SET a = 2 WHERE a = 1; +-- error: UNIQUE constraint error: [a] \ No newline at end of file