Improve default value expressions

This commit is contained in:
Asdine El Hrychy
2021-07-02 18:41:33 +04:00
parent 6af0a6f114
commit fa66f81fb6
16 changed files with 381 additions and 150 deletions

View File

@@ -378,9 +378,7 @@ func calculateValues(a, b Value, operator byte) (res Value, err error) {
return calculateFloats(a, b, operator) return calculateFloats(a, b, operator)
} }
if a.Type == IntegerValue || b.Type == IntegerValue { return calculateIntegers(a, b, operator)
return calculateIntegers(a, b, operator)
}
} }
return NewNullValue(), nil return NewNullValue(), nil

View File

@@ -112,8 +112,7 @@ func (c *catalogCache) Add(tx *database.Transaction, o Relation) error {
// if name is provided, ensure it's not duplicated // if name is provided, ensure it's not duplicated
if name != "" { if name != "" {
_, err := c.Get(RelationIndexType, name) if c.objectExists(name) {
if err == nil {
return errs.AlreadyExistsError{Name: name} return errs.AlreadyExistsError{Name: name}
} }
} else { } else {

View File

@@ -49,7 +49,7 @@ func (c *Catalog) Load(tx *database.Transaction) error {
return err return err
} }
// ensure the catalog table sequence exists // ensure the store sequence exists
err = c.CreateSequence(tx, &database.SequenceInfo{ err = c.CreateSequence(tx, &database.SequenceInfo{
Name: StoreSequence, Name: StoreSequence,
IncrementBy: 1, IncrementBy: 1,
@@ -75,6 +75,17 @@ func (c *Catalog) loadCatalog(tx *database.Transaction) error {
return err return err
} }
for _, tb := range tables {
// bind default values with catalog
for _, fc := range tb.FieldConstraints {
if fc.DefaultValue == nil {
continue
}
fc.DefaultValue.Bind(c)
}
}
// add the __genji_catalog table to the list of tables // add the __genji_catalog table to the list of tables
// so that it can be queried // so that it can be queried
ti := c.CatalogTable.Info.Clone() ti := c.CatalogTable.Info.Clone()
@@ -208,6 +219,15 @@ func (c *Catalog) CreateTable(tx *database.Transaction, tableName string, info *
} }
} }
// bind default values with catalog
for _, fc := range info.FieldConstraints {
if fc.DefaultValue == nil {
continue
}
fc.DefaultValue.Bind(c)
}
err = c.CatalogTable.Insert(tx, info) err = c.CatalogTable.Insert(tx, info)
if err != nil { if err != nil {
return err return err

View File

@@ -23,7 +23,7 @@ type FieldConstraint struct {
IsPrimaryKey bool IsPrimaryKey bool
IsNotNull bool IsNotNull bool
IsUnique bool IsUnique bool
DefaultValue document.Value DefaultValue TableExpression
Identity *FieldConstraintIdentity Identity *FieldConstraintIdentity
IsInferred bool IsInferred bool
InferredBy []document.Path InferredBy []document.Path
@@ -31,38 +31,38 @@ type FieldConstraint struct {
// IsEqual compares f with other member by member. // IsEqual compares f with other member by member.
// Inference is not compared. // Inference is not compared.
func (f *FieldConstraint) IsEqual(other *FieldConstraint) (bool, error) { func (f *FieldConstraint) IsEqual(other *FieldConstraint) bool {
if !f.Path.IsEqual(other.Path) { if !f.Path.IsEqual(other.Path) {
return false, nil return false
} }
if f.Type != other.Type { if f.Type != other.Type {
return false, nil return false
} }
if f.IsPrimaryKey != other.IsPrimaryKey { if f.IsPrimaryKey != other.IsPrimaryKey {
return false, nil return false
} }
if f.IsNotNull != other.IsNotNull { if f.IsNotNull != other.IsNotNull {
return false, nil return false
} }
if f.HasDefaultValue() != other.HasDefaultValue() { if f.HasDefaultValue() != other.HasDefaultValue() {
return false, nil return false
} }
if f.HasDefaultValue() { if f.HasDefaultValue() {
if ok, err := f.DefaultValue.IsEqual(other.DefaultValue); !ok || err != nil { if !f.DefaultValue.IsEqual(other.DefaultValue) {
return ok, err return false
} }
} }
if !f.Identity.IsEqual(other.Identity) { if !f.Identity.IsEqual(other.Identity) {
return false, nil return false
} }
return true, nil return true
} }
func (f *FieldConstraint) String() string { func (f *FieldConstraint) String() string {
@@ -111,7 +111,7 @@ func (f *FieldConstraint) MergeInferred(other *FieldConstraint) {
// HasDefaultValue returns this field contains a default value constraint. // HasDefaultValue returns this field contains a default value constraint.
func (f *FieldConstraint) HasDefaultValue() bool { func (f *FieldConstraint) HasDefaultValue() bool {
return f.DefaultValue.Type != 0 return f.DefaultValue != nil
} }
// FieldConstraints is a list of field constraints. // FieldConstraints is a list of field constraints.
@@ -223,14 +223,8 @@ func (f *FieldConstraints) Add(newFc *FieldConstraint) error {
inferredFc.IsNotNull = nonInferredFc.IsNotNull inferredFc.IsNotNull = nonInferredFc.IsNotNull
inferredFc.IsPrimaryKey = nonInferredFc.IsPrimaryKey inferredFc.IsPrimaryKey = nonInferredFc.IsPrimaryKey
// safe-guard in case we add more fields to the struct // detect if constraints are different
ok, err := c.IsEqual(newFc) if !c.IsEqual(newFc) {
if err != nil {
return err
}
// if constraints are different
if !ok {
return stringutil.Errorf("conflicting constraints: %q and %q", c.String(), newFc.String()) return stringutil.Errorf("conflicting constraints: %q and %q", c.String(), newFc.String())
} }
@@ -260,18 +254,25 @@ func (f *FieldConstraints) Add(newFc *FieldConstraint) error {
} }
} }
// convert default values to the right types // ensure default value type is compatible
targetType := newFc.Type if newFc.DefaultValue != nil && !newFc.Type.IsAny() {
// first, try to evaluate the default value
// if there is no type constraint, numbers must be converted to double v, err := newFc.DefaultValue.Eval(nil)
if newFc.DefaultValue.Type == document.IntegerValue && newFc.Type == 0 { // if there is no error, check if the default value can be converted to the type of the constraint
targetType = document.DoubleValue if err == nil {
} _, err = v.CastAs(newFc.Type)
if newFc.DefaultValue.Type != 0 && targetType != 0 { if err != nil {
var err error return stringutil.Errorf("default value %q cannot be converted to type %q", newFc.DefaultValue, newFc.Type)
newFc.DefaultValue, err = newFc.DefaultValue.CastAs(targetType) }
if err != nil { } else {
return err // if there is an error, we know we are using a function that returns an integer (NEXT VALUE FOR)
// which is the only one compatible for the moment.
// Integers can be converted to other integers, doubles, texts and bools.
switch newFc.Type {
case document.IntegerValue, document.DoubleValue, document.TextValue, document.BoolValue:
default:
return stringutil.Errorf("default value %q cannot be converted to type %q", newFc.DefaultValue, newFc.Type)
}
} }
} }
@@ -280,20 +281,55 @@ func (f *FieldConstraints) Add(newFc *FieldConstraint) error {
} }
// ValidateDocument calls Convert then ensures the document validates against the field constraints. // ValidateDocument calls Convert then ensures the document validates against the field constraints.
func (f FieldConstraints) ValidateDocument(d document.Document) (*document.FieldBuffer, error) { func (f FieldConstraints) ValidateDocument(tx *Transaction, d document.Document) (*document.FieldBuffer, error) {
fb, err := f.ConvertDocument(d) fb := document.NewFieldBuffer()
err := fb.Copy(d)
if err != nil {
return nil, err
}
// generate default values for all fields
for _, fc := range f {
if fc.DefaultValue == nil {
continue
}
_, err := fc.Path.GetValueFromDocument(fb)
if err == nil {
continue
}
if err != document.ErrFieldNotFound {
return nil, err
}
v, err := fc.DefaultValue.Eval(tx)
if err != nil {
return nil, err
}
err = fb.Set(fc.Path, v)
if err != nil {
return nil, err
}
}
fb, err = f.ConvertDocument(fb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// ensure no field is missing // ensure no field is missing
for _, fc := range f { for _, fc := range f {
if !fc.IsNotNull {
continue
}
v, err := fc.Path.GetValueFromDocument(fb) v, err := fc.Path.GetValueFromDocument(fb)
if err == nil { if err == nil {
// if field is found, it has already been converted // if field is found, it has already been converted
// to the right type above. // to the right type above.
// check if it is required but null. // check if it is required but null.
if v.Type == document.NullValue && fc.IsNotNull { if v.Type == document.NullValue {
return nil, &ConstraintViolationError{"NOT NULL", fc.Path} return nil, &ConstraintViolationError{"NOT NULL", fc.Path}
} }
@@ -304,18 +340,7 @@ func (f FieldConstraints) ValidateDocument(d document.Document) (*document.Field
return nil, err return nil, err
} }
// if field is not found return nil, &ConstraintViolationError{"NOT NULL", fc.Path}
// check if there is a default value
if fc.DefaultValue.Type != 0 {
err = fb.Set(fc.Path, fc.DefaultValue)
if err != nil {
return nil, err
}
// if there is no default value
// check if field is required
} else if fc.IsNotNull {
return nil, &ConstraintViolationError{"NOT NULL", fc.Path}
}
} }
return fb, nil return fb, nil
@@ -390,13 +415,16 @@ func (f FieldConstraints) convertScalarAtPath(path document.Path, v document.Val
} }
func (f FieldConstraints) convertDocumentAtPath(path document.Path, d document.Document, conversionFn ConversionFunc) (*document.FieldBuffer, error) { func (f FieldConstraints) convertDocumentAtPath(path document.Path, d document.Document, conversionFn ConversionFunc) (*document.FieldBuffer, error) {
fb := document.NewFieldBuffer() fb, ok := d.(*document.FieldBuffer)
err := fb.Copy(d) if !ok {
if err != nil { fb = document.NewFieldBuffer()
return nil, err err := fb.Copy(d)
if err != nil {
return nil, err
}
} }
err = fb.Apply(func(p document.Path, v document.Value) (document.Value, error) { err := fb.Apply(func(p document.Path, v document.Value) (document.Value, error) {
return f.convertScalarAtPath(append(path, p...), v, conversionFn) return f.convertScalarAtPath(append(path, p...), v, conversionFn)
}) })
@@ -433,3 +461,10 @@ func (f *FieldConstraintIdentity) IsEqual(other *FieldConstraintIdentity) bool {
return f.SequenceName == other.SequenceName && f.Always == other.Always return f.SequenceName == other.SequenceName && f.Always == other.Always
} }
type TableExpression interface {
Bind(catalog Catalog)
Eval(tx *Transaction) (document.Value, error)
IsEqual(other TableExpression) bool
String() string
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/genjidb/genji/document" "github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/database" "github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/testutil" "github.com/genjidb/genji/internal/testutil"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -200,20 +201,44 @@ func TestFieldConstraintsAdd(t *testing.T) {
{ {
"Default value conversion, typed constraint", "Default value conversion, typed constraint",
[]*database.FieldConstraint{{Path: document.NewPath("a"), Type: document.IntegerValue}}, []*database.FieldConstraint{{Path: document.NewPath("a"), Type: document.IntegerValue}},
database.FieldConstraint{Path: document.NewPath("b"), Type: document.IntegerValue, DefaultValue: document.NewDoubleValue(5)}, database.FieldConstraint{Path: document.NewPath("b"), Type: document.IntegerValue, DefaultValue: expr.Constraint(testutil.DoubleValue(5))},
[]*database.FieldConstraint{ []*database.FieldConstraint{
{Path: document.NewPath("a"), Type: document.IntegerValue}, {Path: document.NewPath("a"), Type: document.IntegerValue},
{Path: document.NewPath("b"), Type: document.IntegerValue, DefaultValue: document.NewIntegerValue(5)}, {Path: document.NewPath("b"), Type: document.IntegerValue, DefaultValue: expr.Constraint(testutil.DoubleValue(5))},
}, },
false, false,
}, },
{ {
"Default value conversion, untyped constraint", "Default value conversion, typed constraint, NEXT VALUE FOR",
[]*database.FieldConstraint{{Path: document.NewPath("a"), Type: document.IntegerValue}}, []*database.FieldConstraint{{Path: document.NewPath("a"), Type: document.IntegerValue}},
database.FieldConstraint{Path: document.NewPath("b"), DefaultValue: document.NewIntegerValue(5)}, database.FieldConstraint{Path: document.NewPath("b"), Type: document.IntegerValue, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})},
[]*database.FieldConstraint{ []*database.FieldConstraint{
{Path: document.NewPath("a"), Type: document.IntegerValue}, {Path: document.NewPath("a"), Type: document.IntegerValue},
{Path: document.NewPath("b"), DefaultValue: document.NewDoubleValue(5)}, {Path: document.NewPath("b"), Type: document.IntegerValue, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})},
},
false,
},
{
"Default value conversion, typed constraint, NEXT VALUE FOR with blob",
[]*database.FieldConstraint{{Path: document.NewPath("a"), Type: document.IntegerValue}},
database.FieldConstraint{Path: document.NewPath("b"), Type: document.BlobValue, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})},
nil,
true,
},
{
"Default value conversion, typed constraint, incompatible value",
[]*database.FieldConstraint{{Path: document.NewPath("a"), Type: document.IntegerValue}},
database.FieldConstraint{Path: document.NewPath("b"), Type: document.DoubleValue, DefaultValue: expr.Constraint(testutil.BoolValue(true))},
nil,
true,
},
{
"Default value conversion, untyped constraint",
[]*database.FieldConstraint{{Path: document.NewPath("a"), Type: document.IntegerValue}},
database.FieldConstraint{Path: document.NewPath("b"), DefaultValue: expr.Constraint(testutil.IntegerValue(5))},
[]*database.FieldConstraint{
{Path: document.NewPath("a"), Type: document.IntegerValue},
{Path: document.NewPath("b"), DefaultValue: expr.Constraint(testutil.IntegerValue(5))},
}, },
false, false,
}, },
@@ -282,7 +307,7 @@ func TestFieldConstraintsConvert(t *testing.T) {
true, true,
}, },
{ {
database.FieldConstraints{{Path: document.NewPath("a"), DefaultValue: document.NewIntegerValue(10)}}, database.FieldConstraints{{Path: document.NewPath("a"), DefaultValue: expr.Constraint(testutil.IntegerValue(10))}},
document.NewPath("a"), document.NewPath("a"),
document.NewTextValue("foo"), document.NewTextValue("foo"),
document.NewTextValue("foo"), document.NewTextValue("foo"),

View File

@@ -57,7 +57,7 @@ func (t *Table) InsertWithConflictResolution(d document.Document, onConflict OnI
return nil, errors.New("cannot write to read-only table") return nil, errors.New("cannot write to read-only table")
} }
fb, err := t.Info.FieldConstraints.ValidateDocument(d) fb, err := t.Info.FieldConstraints.ValidateDocument(t.Tx, d)
if err != nil { if err != nil {
if onConflict != nil { if onConflict != nil {
if ce, ok := err.(*ConstraintViolationError); ok && ce.Constraint == "NOT NULL" { if ce, ok := err.(*ConstraintViolationError); ok && ce.Constraint == "NOT NULL" {
@@ -227,7 +227,7 @@ func (t *Table) Replace(key []byte, d document.Document) (document.Document, err
return nil, errors.New("cannot write to read-only table") return nil, errors.New("cannot write to read-only table")
} }
d, err := t.Info.FieldConstraints.ValidateDocument(d) d, err := t.Info.FieldConstraints.ValidateDocument(t.Tx, d)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/genjidb/genji/internal/binarysort" "github.com/genjidb/genji/internal/binarysort"
"github.com/genjidb/genji/internal/catalog" "github.com/genjidb/genji/internal/catalog"
"github.com/genjidb/genji/internal/database" "github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query/statement" "github.com/genjidb/genji/internal/query/statement"
"github.com/genjidb/genji/internal/testutil" "github.com/genjidb/genji/internal/testutil"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -386,8 +387,8 @@ func TestTableInsert(t *testing.T) {
tb := createTable(t, tx, db.Catalog, database.TableInfo{ tb := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test", TableName: "test",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), document.DocumentValue, false, false, false, document.Value{}, nil, true, []document.Path{testutil.ParseDocumentPath(t, "foo.bar")}}, {testutil.ParseDocumentPath(t, "foo"), document.DocumentValue, false, false, false, nil, nil, true, []document.Path{testutil.ParseDocumentPath(t, "foo.bar")}},
{testutil.ParseDocumentPath(t, "foo.bar"), document.IntegerValue, false, false, false, document.Value{}, nil, true, []document.Path{testutil.ParseDocumentPath(t, "foo")}}, {testutil.ParseDocumentPath(t, "foo.bar"), document.IntegerValue, false, false, false, nil, nil, true, []document.Path{testutil.ParseDocumentPath(t, "foo")}},
}, },
}) })
@@ -428,7 +429,7 @@ func TestTableInsert(t *testing.T) {
err := db.Catalog.CreateTable(tx, "test", &database.TableInfo{ err := db.Catalog.CreateTable(tx, "test", &database.TableInfo{
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), document.DoubleValue, false, false, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), document.DoubleValue, false, false, false, nil, nil, false, nil},
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
@@ -451,7 +452,7 @@ func TestTableInsert(t *testing.T) {
tb1 := createTable(t, tx, db.Catalog, database.TableInfo{ tb1 := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test1", TableName: "test1",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), 0, false, true, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), 0, false, true, false, nil, nil, false, nil},
}, },
}) })
@@ -459,7 +460,7 @@ func TestTableInsert(t *testing.T) {
tb2 := createTable(t, tx, db.Catalog, database.TableInfo{ tb2 := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test2", TableName: "test2",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), document.IntegerValue, false, true, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), document.IntegerValue, false, true, false, nil, nil, false, nil},
}, },
}) })
@@ -502,7 +503,7 @@ func TestTableInsert(t *testing.T) {
tb1 := createTable(t, tx, db.Catalog, database.TableInfo{ tb1 := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test1", TableName: "test1",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), 0, false, true, false, document.NewIntegerValue(42), nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), 0, false, true, false, expr.Constraint(testutil.IntegerValue(42)), nil, false, nil},
}, },
}) })
@@ -510,7 +511,7 @@ func TestTableInsert(t *testing.T) {
tb2 := createTable(t, tx, db.Catalog, database.TableInfo{ tb2 := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test2", TableName: "test2",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), document.IntegerValue, false, true, false, document.NewIntegerValue(42), nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), document.IntegerValue, false, true, false, expr.Constraint(testutil.IntegerValue(42)), nil, false, nil},
}, },
}) })
@@ -564,7 +565,7 @@ func TestTableInsert(t *testing.T) {
tb := createTable(t, tx, db.Catalog, database.TableInfo{ tb := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test1", TableName: "test1",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo[1]"), 0, false, true, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo[1]"), 0, false, true, false, nil, nil, false, nil},
}, },
}) })
@@ -585,7 +586,7 @@ func TestTableInsert(t *testing.T) {
tb := createTable(t, tx, db.Catalog, database.TableInfo{ tb := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test", TableName: "test",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), 0, true, true, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), 0, true, true, false, nil, nil, false, nil},
}}) }})
doc := document.NewFieldBuffer(). doc := document.NewFieldBuffer().
@@ -607,7 +608,7 @@ func TestTableInsert(t *testing.T) {
tb := createTable(t, tx, db.Catalog, database.TableInfo{ tb := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test", TableName: "test",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), 0, true, true, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), 0, true, true, false, nil, nil, false, nil},
}}) }})
doc := document.NewFieldBuffer(). doc := document.NewFieldBuffer().
@@ -701,7 +702,7 @@ func TestTableInsert(t *testing.T) {
tb := createTable(t, tx, db.Catalog, database.TableInfo{ tb := createTable(t, tx, db.Catalog, database.TableInfo{
TableName: "test", TableName: "test",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), 0, false, true, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), 0, false, true, false, nil, nil, false, nil},
}, },
}) })
@@ -726,7 +727,7 @@ func TestTableInsert(t *testing.T) {
err := db.Catalog.CreateTable(tx, "test", &database.TableInfo{ err := db.Catalog.CreateTable(tx, "test", &database.TableInfo{
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), 0, true, true, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), 0, true, true, false, nil, nil, false, nil},
}}) }})
require.NoError(t, err) require.NoError(t, err)
@@ -752,7 +753,7 @@ func TestTableInsert(t *testing.T) {
err := db.Catalog.CreateTable(tx, "test", &database.TableInfo{ err := db.Catalog.CreateTable(tx, "test", &database.TableInfo{
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{testutil.ParseDocumentPath(t, "foo"), 0, false, true, false, document.Value{}, nil, false, nil}, {testutil.ParseDocumentPath(t, "foo"), 0, false, true, false, nil, nil, false, nil},
}}) }})
require.NoError(t, err) require.NoError(t, err)

View File

@@ -0,0 +1,54 @@
package expr
import (
"errors"
"github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/environment"
)
type ConstraintExpr struct {
Expr Expr
Catalog database.Catalog
}
func Constraint(e Expr) *ConstraintExpr {
return &ConstraintExpr{
Expr: e,
}
}
func (t *ConstraintExpr) Eval(tx *database.Transaction) (document.Value, error) {
var env environment.Environment
env.Catalog = t.Catalog
env.Tx = tx
if t.Expr == nil {
return NullLitteral, errors.New("missing expression")
}
return t.Expr.Eval(&env)
}
func (t *ConstraintExpr) Bind(catalog database.Catalog) {
t.Catalog = catalog
}
func (t *ConstraintExpr) IsEqual(other database.TableExpression) bool {
if t == nil {
return other == nil
}
if other == nil {
return false
}
o, ok := other.(*ConstraintExpr)
if !ok {
return false
}
return Equal(t.Expr, o.Expr)
}
func (t *ConstraintExpr) String() string {
return t.Expr.String()
}

View File

@@ -129,7 +129,7 @@ func Walk(e Expr, fn func(Expr) bool) bool {
} }
} }
return false return true
} }
type NextValueFor struct { type NextValueFor struct {
@@ -141,6 +141,10 @@ func (n NextValueFor) Eval(env *environment.Environment) (document.Value, error)
catalog := env.GetCatalog() catalog := env.GetCatalog()
tx := env.GetTx() tx := env.GetTx()
if catalog == nil || tx == nil {
return NullLitteral, stringutil.Errorf(`NEXT VALUE FOR cannot be evaluated`)
}
seq, err := catalog.GetSequence(n.SeqName) seq, err := catalog.GetSequence(n.SeqName)
if err != nil { if err != nil {
return NullLitteral, err return NullLitteral, err

View File

@@ -1,6 +1,8 @@
package expr package expr
import ( import (
"errors"
"github.com/genjidb/genji/document" "github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/environment" "github.com/genjidb/genji/internal/environment"
"github.com/genjidb/genji/internal/sql/scanner" "github.com/genjidb/genji/internal/sql/scanner"
@@ -37,6 +39,10 @@ func (op *simpleOperator) Token() scanner.Token {
} }
func (op *simpleOperator) eval(env *environment.Environment, fn func(a, b document.Value) (document.Value, error)) (document.Value, error) { func (op *simpleOperator) eval(env *environment.Environment, fn func(a, b document.Value) (document.Value, error)) (document.Value, error) {
if op.a == nil || op.b == nil {
return NullLitteral, errors.New("missing operand")
}
va, err := op.a.Eval(env) va, err := op.a.Eval(env)
if err != nil { if err != nil {
return NullLitteral, err return NullLitteral, err

View File

@@ -5,6 +5,7 @@ import (
"github.com/genjidb/genji/document" "github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/database" "github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/sql/parser" "github.com/genjidb/genji/internal/sql/parser"
"github.com/genjidb/genji/internal/testutil" "github.com/genjidb/genji/internal/testutil"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -201,9 +202,9 @@ func TestCreateTable(t *testing.T) {
constraints database.FieldConstraints constraints database.FieldConstraints
fails bool fails bool
}{ }{
{"With default, no type and integer default", "CREATE TABLE test(foo DEFAULT 10)", database.FieldConstraints{{Path: parsePath(t, "foo"), DefaultValue: document.NewDoubleValue(10)}}, false}, {"With default, no type and integer default", "CREATE TABLE test(foo DEFAULT 10)", database.FieldConstraints{{Path: parsePath(t, "foo"), DefaultValue: expr.Constraint(testutil.IntegerValue(10))}}, false},
{"With default, double type and integer default", "CREATE TABLE test(foo DOUBLE DEFAULT 10)", database.FieldConstraints{{Path: parsePath(t, "foo"), Type: document.DoubleValue, DefaultValue: document.NewDoubleValue(10)}}, false}, {"With default, double type and integer default", "CREATE TABLE test(foo DOUBLE DEFAULT 10)", database.FieldConstraints{{Path: parsePath(t, "foo"), Type: document.DoubleValue, DefaultValue: expr.Constraint(testutil.IntegerValue(10))}}, false},
{"With default, some type and compatible default", "CREATE TABLE test(foo BOOL DEFAULT 10)", database.FieldConstraints{{Path: parsePath(t, "foo"), Type: document.BoolValue, DefaultValue: document.NewBoolValue(true)}}, false}, {"With default, some type and compatible default", "CREATE TABLE test(foo BOOL DEFAULT 10)", database.FieldConstraints{{Path: parsePath(t, "foo"), Type: document.BoolValue, DefaultValue: expr.Constraint(testutil.IntegerValue(10))}}, false},
{"With default, some type and incompatible default", "CREATE TABLE test(foo BOOL DEFAULT 10.5)", nil, true}, {"With default, some type and incompatible default", "CREATE TABLE test(foo BOOL DEFAULT 10.5)", nil, true},
} }
@@ -221,6 +222,12 @@ func TestCreateTable(t *testing.T) {
tb, err := db.Catalog.GetTable(tx, "test") tb, err := db.Catalog.GetTable(tx, "test")
require.NoError(t, err) require.NoError(t, err)
for _, fc := range test.constraints {
if fc.DefaultValue != nil {
fc.DefaultValue.(*expr.ConstraintExpr).Catalog = db.Catalog
}
}
require.Equal(t, test.constraints, tb.Info.FieldConstraints) require.Equal(t, test.constraints, tb.Info.FieldConstraints)
}) })
} }

View File

@@ -122,6 +122,31 @@ func TestInsertStmt(t *testing.T) {
testutil.RequireStreamEq(t, ``, res) testutil.RequireStreamEq(t, ``, res)
}) })
t.Run("with NEXT VALUE FOR", func(t *testing.T) {
db, err := genji.Open(":memory:")
require.NoError(t, err)
defer db.Close()
err = db.Exec(`CREATE SEQUENCE seq; CREATE TABLE test(a int, b int default NEXT VALUE FOR seq)`)
require.NoError(t, err)
err = db.Exec(`insert into test (a) VALUES (1), (2), (3)`)
require.NoError(t, err)
res, err := db.Query("SELECT * FROM test")
require.NoError(t, err)
defer res.Close()
var b bytes.Buffer
testutil.IteratorToJSONArray(&b, res)
require.JSONEq(t, `
[{"a": 1, "b": 1},
{"a": 2, "b": 2},
{"a": 3, "b": 3}]
`, b.String())
})
// t.Run("without RETURNING", func(t *testing.T) { // t.Run("without RETURNING", func(t *testing.T) {
// db, err := genji.Open(":memory:") // db, err := genji.Open(":memory:")
// require.NoError(t, err) // require.NoError(t, err)

View File

@@ -5,6 +5,7 @@ import (
"github.com/genjidb/genji/document" "github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/database" "github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query/statement" "github.com/genjidb/genji/internal/query/statement"
"github.com/genjidb/genji/internal/sql/parser" "github.com/genjidb/genji/internal/sql/parser"
"github.com/genjidb/genji/internal/testutil" "github.com/genjidb/genji/internal/testutil"
@@ -66,7 +67,7 @@ func TestParserAlterTableAddField(t *testing.T) {
Path: document.Path(testutil.ParsePath(t, "bar")), Path: document.Path(testutil.ParsePath(t, "bar")),
Type: document.IntegerValue, Type: document.IntegerValue,
IsNotNull: true, IsNotNull: true,
DefaultValue: document.NewIntegerValue(0), DefaultValue: expr.Constraint(expr.LiteralValue(document.NewIntegerValue(0))),
}, },
}, false}, }, false},
{"With error / missing FIELD keyword", "ALTER TABLE foo ADD bar", nil, true}, {"With error / missing FIELD keyword", "ALTER TABLE foo ADD bar", nil, true},

View File

@@ -4,7 +4,7 @@ import (
"math" "math"
"github.com/genjidb/genji/internal/database" "github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/environment" "github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query/statement" "github.com/genjidb/genji/internal/query/statement"
"github.com/genjidb/genji/internal/sql/scanner" "github.com/genjidb/genji/internal/sql/scanner"
"github.com/genjidb/genji/internal/stringutil" "github.com/genjidb/genji/internal/stringutil"
@@ -71,7 +71,7 @@ func (p *Parser) parseFieldDefinition(fc *database.FieldConstraint) (err error)
return err return err
} }
if fc.Type.IsAny() && fc.DefaultValue.Type.IsAny() && !fc.IsNotNull && !fc.IsPrimaryKey && !fc.IsUnique { if fc.Type.IsAny() && fc.DefaultValue == nil && !fc.IsNotNull && !fc.IsPrimaryKey && !fc.IsUnique {
tok, pos, lit := p.ScanIgnoreWhitespace() tok, pos, lit := p.ScanIgnoreWhitespace()
return newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", "TYPE"}, pos) return newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", "TYPE"}, pos)
} }
@@ -118,7 +118,10 @@ func (p *Parser) parseConstraints(stmt *statement.CreateTableStmt) error {
return err return err
} }
stmt.Info.FieldConstraints = append(stmt.Info.FieldConstraints, &fc) err = stmt.Info.FieldConstraints.Add(&fc)
if err != nil {
return err
}
} }
if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA { if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA {
@@ -176,23 +179,44 @@ func (p *Parser) parseFieldConstraint(fc *database.FieldConstraint) error {
fc.IsNotNull = true fc.IsNotNull = true
case scanner.DEFAULT: case scanner.DEFAULT:
// Parse default value expression.
e, err := p.parseUnaryExpr()
if err != nil {
return err
}
d, err := e.Eval(&environment.Environment{})
if err != nil {
return err
}
// if it has already a default value we return an error // if it has already a default value we return an error
if fc.HasDefaultValue() { if fc.HasDefaultValue() {
return newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", ")"}, pos) return newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", ")"}, pos)
} }
fc.DefaultValue = d // Parse default value expression.
// Only a few tokens are allowed.
e, err := p.parseExprWithMinPrecedence(scanner.EQ.Precedence(),
scanner.EQ,
scanner.NEQ,
scanner.BITWISEOR,
scanner.BITWISEXOR,
scanner.BITWISEAND,
scanner.LT,
scanner.LTE,
scanner.GT,
scanner.GTE,
scanner.ADD,
scanner.SUB,
scanner.MUL,
scanner.DIV,
scanner.MOD,
scanner.CONCAT,
scanner.INTEGER,
scanner.NUMBER,
scanner.STRING,
scanner.TRUE,
scanner.FALSE,
scanner.NULL,
scanner.LPAREN, // only opening parenthesis are necessary
scanner.LBRACKET, // only opening brackets are necessary
scanner.NEXT,
)
if err != nil {
return err
}
fc.DefaultValue = expr.Constraint(e)
case scanner.UNIQUE: case scanner.UNIQUE:
// if it's already unique we return an error // if it's already unique we return an error
if fc.IsUnique { if fc.IsUnique {

View File

@@ -6,6 +6,7 @@ import (
"github.com/genjidb/genji/document" "github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/database" "github.com/genjidb/genji/internal/database"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/query/statement" "github.com/genjidb/genji/internal/query/statement"
"github.com/genjidb/genji/internal/sql/parser" "github.com/genjidb/genji/internal/sql/parser"
"github.com/genjidb/genji/internal/testutil" "github.com/genjidb/genji/internal/testutil"
@@ -21,7 +22,7 @@ func TestParserCreateTable(t *testing.T) {
}{ }{
{"Basic", "CREATE TABLE test", &statement.CreateTableStmt{Info: database.TableInfo{TableName: "test"}}, false}, {"Basic", "CREATE TABLE test", &statement.CreateTableStmt{Info: database.TableInfo{TableName: "test"}}, false},
{"If not exists", "CREATE TABLE IF NOT EXISTS test", &statement.CreateTableStmt{Info: database.TableInfo{TableName: "test"}, IfNotExists: true}, false}, {"If not exists", "CREATE TABLE IF NOT EXISTS test", &statement.CreateTableStmt{Info: database.TableInfo{TableName: "test"}, IfNotExists: true}, false},
{"Path only", "CREATE TABLE test(a)", &statement.CreateTableStmt{}, true}, {"Path only", "CREATE TABLE test(a)", nil, true},
{"With primary key", "CREATE TABLE test(foo INTEGER PRIMARY KEY)", {"With primary key", "CREATE TABLE test(foo INTEGER PRIMARY KEY)",
&statement.CreateTableStmt{ &statement.CreateTableStmt{
Info: database.TableInfo{ Info: database.TableInfo{
@@ -31,8 +32,7 @@ func TestParserCreateTable(t *testing.T) {
}, },
}, },
}, false}, }, false},
{"With primary key twice", "CREATE TABLE test(foo PRIMARY KEY PRIMARY KEY)", {"With primary key twice", "CREATE TABLE test(foo PRIMARY KEY PRIMARY KEY)", nil, true},
&statement.CreateTableStmt{}, true},
{"With type", "CREATE TABLE test(foo INTEGER)", {"With type", "CREATE TABLE test(foo INTEGER)",
&statement.CreateTableStmt{ &statement.CreateTableStmt{
Info: database.TableInfo{ Info: database.TableInfo{
@@ -56,10 +56,13 @@ func TestParserCreateTable(t *testing.T) {
Info: database.TableInfo{ Info: database.TableInfo{
TableName: "test", TableName: "test",
FieldConstraints: []*database.FieldConstraint{ FieldConstraints: []*database.FieldConstraint{
{Path: document.Path(testutil.ParsePath(t, "foo")), DefaultValue: document.NewTextValue("10")}, {Path: document.Path(testutil.ParsePath(t, "foo")), DefaultValue: expr.Constraint(expr.LiteralValue(document.NewTextValue("10")))},
}, },
}, },
}, false}, }, false},
{"With default twice", "CREATE TABLE test(foo DEFAULT 10 DEFAULT 10)", nil, true},
{"With forbidden tokens", "CREATE TABLE test(foo DEFAULT a)", nil, true},
{"With forbidden tokens", "CREATE TABLE test(foo DEFAULT 1 AND 2)", nil, true},
{"With unique", "CREATE TABLE test(foo UNIQUE)", {"With unique", "CREATE TABLE test(foo UNIQUE)",
&statement.CreateTableStmt{ &statement.CreateTableStmt{
Info: database.TableInfo{ Info: database.TableInfo{
@@ -69,12 +72,9 @@ func TestParserCreateTable(t *testing.T) {
}, },
}, },
}, false}, }, false},
{"With default twice", "CREATE TABLE test(foo DEFAULT 10 DEFAULT 10)",
&statement.CreateTableStmt{}, true}, {"With not null twice", "CREATE TABLE test(foo NOT NULL NOT NULL)", nil, true},
{"With not null twice", "CREATE TABLE test(foo NOT NULL NOT NULL)", {"With unique twice", "CREATE TABLE test(foo UNIQUE UNIQUE)", nil, true},
&statement.CreateTableStmt{}, true},
{"With unique twice", "CREATE TABLE test(foo UNIQUE UNIQUE)",
&statement.CreateTableStmt{}, true},
{"With type and not null", "CREATE TABLE test(foo INTEGER NOT NULL)", {"With type and not null", "CREATE TABLE test(foo INTEGER NOT NULL)",
&statement.CreateTableStmt{ &statement.CreateTableStmt{
Info: database.TableInfo{ Info: database.TableInfo{
@@ -166,8 +166,7 @@ func TestParserCreateTable(t *testing.T) {
}, },
}, false}, }, false},
{"With table constraints / duplicate pk on same path", "CREATE TABLE test(foo INTEGER PRIMARY KEY, PRIMARY KEY (foo))", nil, true}, {"With table constraints / duplicate pk on same path", "CREATE TABLE test(foo INTEGER PRIMARY KEY, PRIMARY KEY (foo))", nil, true},
{"With multiple primary keys", "CREATE TABLE test(foo PRIMARY KEY, bar PRIMARY KEY)", {"With multiple primary keys", "CREATE TABLE test(foo PRIMARY KEY, bar PRIMARY KEY)", nil, true},
&statement.CreateTableStmt{}, true},
{"With all supported fixed size data types", {"With all supported fixed size data types",
"CREATE TABLE test(d double, b bool)", "CREATE TABLE test(d double, b bool)",
&statement.CreateTableStmt{ &statement.CreateTableStmt{

View File

@@ -31,13 +31,13 @@ func (p *Parser) ParseExpr() (e expr.Expr, err error) {
return p.parseExprWithMinPrecedence(0) return p.parseExprWithMinPrecedence(0)
} }
func (p *Parser) parseExprWithMinPrecedence(precedence int) (e expr.Expr, err error) { func (p *Parser) parseExprWithMinPrecedence(precedence int, allowed ...scanner.Token) (e expr.Expr, err error) {
// Dummy root node. // Dummy root node.
var root expr.Operator = new(dummyOperator) var root expr.Operator = new(dummyOperator)
// Parse a non-binary expression type to start. // Parse a non-binary expression type to start.
// This variable will always be the root of the expression tree. // This variable will always be the root of the expression tree.
e, err = p.parseUnaryExpr() e, err = p.parseUnaryExpr(allowed...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -46,7 +46,7 @@ func (p *Parser) parseExprWithMinPrecedence(precedence int) (e expr.Expr, err er
// Loop over operations and unary exprs and build a tree based on precedence. // Loop over operations and unary exprs and build a tree based on precedence.
for { for {
// If the next token is NOT an operator then return the expression. // If the next token is NOT an operator then return the expression.
op, tok, err := p.parseOperator(precedence) op, tok, err := p.parseOperator(precedence, allowed...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -56,7 +56,7 @@ func (p *Parser) parseExprWithMinPrecedence(precedence int) (e expr.Expr, err er
var rhs expr.Expr var rhs expr.Expr
if rhs, err = p.parseUnaryExpr(); err != nil { if rhs, err = p.parseUnaryExpr(allowed...); err != nil {
return nil, err return nil, err
} }
@@ -76,75 +76,89 @@ func (p *Parser) parseExprWithMinPrecedence(precedence int) (e expr.Expr, err er
} }
} }
func (p *Parser) parseOperator(minPrecedence int) (func(lhs, rhs expr.Expr) expr.Expr, scanner.Token, error) { func (p *Parser) parseOperator(minPrecedence int, allowed ...scanner.Token) (func(lhs, rhs expr.Expr) expr.Expr, scanner.Token, error) {
op, _, _ := p.ScanIgnoreWhitespace() op, _, _ := p.ScanIgnoreWhitespace()
if !op.IsOperator() && op != scanner.NOT { if !op.IsOperator() && op != scanner.NOT {
p.Unscan() p.Unscan()
return nil, 0, nil return nil, 0, nil
} }
if !tokenIsAllowed(op, allowed...) {
p.Unscan()
return nil, 0, nil
}
// Ignore currently unused operators. // Ignore currently unused operators.
if op == scanner.EQREGEX || op == scanner.NEQREGEX { if op == scanner.EQREGEX || op == scanner.NEQREGEX {
p.Unscan() p.Unscan()
return nil, 0, nil return nil, 0, nil
} }
switch { if op == scanner.NOT {
case op == scanner.EQ && op.Precedence() >= minPrecedence: tok, pos, lit := p.ScanIgnoreWhitespace()
if tok.Precedence() >= minPrecedence {
switch {
case tok == scanner.IN && tok.Precedence() >= minPrecedence:
return expr.NotIn, op, nil
case tok == scanner.LIKE && tok.Precedence() >= minPrecedence:
return expr.NotLike, op, nil
}
}
return nil, 0, newParseError(scanner.Tokstr(tok, lit), []string{"IN, LIKE"}, pos)
}
if op.Precedence() < minPrecedence {
p.Unscan()
return nil, 0, nil
}
switch op {
case scanner.EQ:
return expr.Eq, op, nil return expr.Eq, op, nil
case op == scanner.NEQ && op.Precedence() >= minPrecedence: case scanner.NEQ:
return expr.Neq, op, nil return expr.Neq, op, nil
case op == scanner.GT && op.Precedence() >= minPrecedence: case scanner.GT:
return expr.Gt, op, nil return expr.Gt, op, nil
case op == scanner.GTE && op.Precedence() >= minPrecedence: case scanner.GTE:
return expr.Gte, op, nil return expr.Gte, op, nil
case op == scanner.LT && op.Precedence() >= minPrecedence: case scanner.LT:
return expr.Lt, op, nil return expr.Lt, op, nil
case op == scanner.LTE && op.Precedence() >= minPrecedence: case scanner.LTE:
return expr.Lte, op, nil return expr.Lte, op, nil
case op == scanner.AND && op.Precedence() >= minPrecedence: case scanner.AND:
return expr.And, op, nil return expr.And, op, nil
case op == scanner.OR && op.Precedence() >= minPrecedence: case scanner.OR:
return expr.Or, op, nil return expr.Or, op, nil
case op == scanner.ADD && op.Precedence() >= minPrecedence: case scanner.ADD:
return expr.Add, op, nil return expr.Add, op, nil
case op == scanner.SUB && op.Precedence() >= minPrecedence: case scanner.SUB:
return expr.Sub, op, nil return expr.Sub, op, nil
case op == scanner.MUL && op.Precedence() >= minPrecedence: case scanner.MUL:
return expr.Mul, op, nil return expr.Mul, op, nil
case op == scanner.DIV && op.Precedence() >= minPrecedence: case scanner.DIV:
return expr.Div, op, nil return expr.Div, op, nil
case op == scanner.MOD && op.Precedence() >= minPrecedence: case scanner.MOD:
return expr.Mod, op, nil return expr.Mod, op, nil
case op == scanner.BITWISEAND && op.Precedence() >= minPrecedence: case scanner.BITWISEAND:
return expr.BitwiseAnd, op, nil return expr.BitwiseAnd, op, nil
case op == scanner.BITWISEOR && op.Precedence() >= minPrecedence: case scanner.BITWISEOR:
return expr.BitwiseOr, op, nil return expr.BitwiseOr, op, nil
case op == scanner.BITWISEXOR && op.Precedence() >= minPrecedence: case scanner.BITWISEXOR:
return expr.BitwiseXor, op, nil return expr.BitwiseXor, op, nil
case op == scanner.IN && op.Precedence() >= minPrecedence: case scanner.IN:
return expr.In, op, nil return expr.In, op, nil
case op == scanner.IS && op.Precedence() >= minPrecedence: case scanner.IS:
if tok, _, _ := p.ScanIgnoreWhitespace(); tok == scanner.NOT { if tok, _, _ := p.ScanIgnoreWhitespace(); tok == scanner.NOT {
return expr.IsNot, op, nil return expr.IsNot, op, nil
} }
p.Unscan() p.Unscan()
return expr.Is, op, nil return expr.Is, op, nil
case op == scanner.NOT: case scanner.LIKE:
tok, pos, lit := p.ScanIgnoreWhitespace()
switch {
case tok == scanner.IN && tok.Precedence() >= minPrecedence:
return expr.NotIn, op, nil
case tok == scanner.LIKE && tok.Precedence() >= minPrecedence:
return expr.NotLike, op, nil
}
return nil, 0, newParseError(scanner.Tokstr(tok, lit), []string{"IN, LIKE"}, pos)
case op == scanner.LIKE && op.Precedence() >= minPrecedence:
return expr.Like, op, nil return expr.Like, op, nil
case op == scanner.CONCAT && op.Precedence() >= minPrecedence: case scanner.CONCAT:
return expr.Concat, op, nil return expr.Concat, op, nil
case op == scanner.BETWEEN && op.Precedence() >= minPrecedence: case scanner.BETWEEN:
a, err := p.parseExprWithMinPrecedence(op.Precedence()) a, err := p.parseExprWithMinPrecedence(op.Precedence())
if err != nil { if err != nil {
return nil, op, err return nil, op, err
@@ -163,8 +177,14 @@ func (p *Parser) parseOperator(minPrecedence int) (func(lhs, rhs expr.Expr) expr
} }
// parseUnaryExpr parses an non-binary expression. // parseUnaryExpr parses an non-binary expression.
func (p *Parser) parseUnaryExpr() (expr.Expr, error) { func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) {
tok, pos, lit := p.ScanIgnoreWhitespace() tok, pos, lit := p.ScanIgnoreWhitespace()
if !tokenIsAllowed(tok, allowed...) {
p.Unscan()
return nil, nil
}
switch tok { switch tok {
case scanner.CAST: case scanner.CAST:
p.Unscan() p.Unscan()
@@ -268,7 +288,7 @@ func (p *Parser) parseUnaryExpr() (expr.Expr, error) {
return expr.NextValueFor{SeqName: seqName}, nil return expr.NextValueFor{SeqName: seqName}, nil
default: default:
return nil, newParseError(scanner.Tokstr(tok, lit), []string{"identifier", "string", "number", "bool"}, pos) return nil, newParseError(scanner.Tokstr(tok, lit), nil, pos)
} }
} }
@@ -639,3 +659,16 @@ func (p *Parser) parseCastExpression() (expr.Expr, error) {
return expr.CastFunc{Expr: e, CastAs: tp}, nil return expr.CastFunc{Expr: e, CastAs: tp}, nil
} }
// tokenIsAllowed is a helper function that determines if a token is allowed.
func tokenIsAllowed(tok scanner.Token, allowed ...scanner.Token) bool {
if allowed == nil {
return true
}
for _, a := range allowed {
if tok == a {
return true
}
}
return false
}