distribute Value logic to each implementation

This commit is contained in:
Asdine El Hrychy
2024-01-20 18:47:14 +04:00
parent f970bc625e
commit 21ac003166
16 changed files with 1299 additions and 501 deletions

View File

@@ -132,14 +132,14 @@ func (r *Range) IsEqual(other *Range) bool {
} }
for i := range r.Min { for i := range r.Min {
eq, err := types.IsEqual(r.Min[i], other.Min[i]) eq, err := r.Min[i].EQ(other.Min[i])
if err != nil || !eq { if err != nil || !eq {
return false return false
} }
} }
for i := range r.Max { for i := range r.Max {
eq, err := types.IsEqual(r.Max[i], other.Max[i]) eq, err := r.Max[i].EQ(other.Max[i])
if err != nil || !eq { if err != nil || !eq {
return false return false
} }

View File

@@ -116,7 +116,7 @@ func TestTableGetObject(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
fc, err := res.Get("fieldc") fc, err := res.Get("fieldc")
assert.NoError(t, err) assert.NoError(t, err)
ok, err := types.IsEqual(vc, fc) ok, err := vc.EQ(fc)
assert.NoError(t, err) assert.NoError(t, err)
require.True(t, ok) require.True(t, ok)
}) })

View File

@@ -21,21 +21,21 @@ func (op *arithmeticOperator) Eval(env *environment.Environment) (types.Value, e
return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) { return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) {
switch op.simpleOperator.Tok { switch op.simpleOperator.Tok {
case scanner.ADD: case scanner.ADD:
return types.Add(a, b) return a.Add(b)
case scanner.SUB: case scanner.SUB:
return types.Sub(a, b) return a.Sub(b)
case scanner.MUL: case scanner.MUL:
return types.Mul(a, b) return a.Mul(b)
case scanner.DIV: case scanner.DIV:
return types.Div(a, b) return a.Div(b)
case scanner.MOD: case scanner.MOD:
return types.Mod(a, b) return a.Mod(b)
case scanner.BITWISEAND: case scanner.BITWISEAND:
return types.BitwiseAnd(a, b) return a.BitwiseAnd(b)
case scanner.BITWISEOR: case scanner.BITWISEOR:
return types.BitwiseOr(a, b) return a.BitwiseOr(b)
case scanner.BITWISEXOR: case scanner.BITWISEXOR:
return types.BitwiseXor(a, b) return a.BitwiseXor(b)
} }
panic("unknown arithmetic token") panic("unknown arithmetic token")

View File

@@ -40,17 +40,21 @@ func (op *cmpOp) Eval(env *environment.Environment) (types.Value, error) {
func (op *cmpOp) compare(l, r types.Value) (bool, error) { func (op *cmpOp) compare(l, r types.Value) (bool, error) {
switch op.Tok { switch op.Tok {
case scanner.EQ: case scanner.EQ:
return types.IsEqual(l, r) return l.EQ(r)
case scanner.NEQ: case scanner.NEQ:
return types.IsNotEqual(l, r) eq, err := l.EQ(r)
if err != nil {
return false, err
}
return !eq, nil
case scanner.GT: case scanner.GT:
return types.IsGreaterThan(l, r) return l.GT(r)
case scanner.GTE: case scanner.GTE:
return types.IsGreaterThanOrEqual(l, r) return l.GTE(r)
case scanner.LT: case scanner.LT:
return types.IsLesserThan(l, r) return l.LT(r)
case scanner.LTE: case scanner.LTE:
return types.IsLesserThanOrEqual(l, r) return l.LTE(r)
default: default:
panic(fmt.Sprintf("unknown token %v", op.Tok)) panic(fmt.Sprintf("unknown token %v", op.Tok))
} }
@@ -106,21 +110,20 @@ func (op *BetweenOperator) Eval(env *environment.Environment) (types.Value, erro
} }
return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) { return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) {
if a.Type() == types.TypeNull || b.Type() == types.TypeNull { if a.Type() == types.TypeNull || b.Type() == types.TypeNull || x.Type() == types.TypeNull {
return NullLiteral, nil return NullLiteral, nil
} }
ok, err := types.IsGreaterThanOrEqual(x, a) ok, err := x.Between(a, b)
if !ok || err != nil { if err != nil {
return FalseLiteral, err return NullLiteral, err
} }
ok, err = types.IsLesserThanOrEqual(x, b) if ok {
if !ok || err != nil { return TrueLiteral, nil
return FalseLiteral, err
} }
return TrueLiteral, nil return FalseLiteral, nil
}) })
} }
@@ -198,7 +201,7 @@ func Is(a, b Expr) Expr {
func (op *IsOperator) Eval(env *environment.Environment) (types.Value, error) { func (op *IsOperator) Eval(env *environment.Environment) (types.Value, error) {
return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) { return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) {
ok, err := types.IsEqual(a, b) ok, err := a.EQ(b)
if err != nil { if err != nil {
return NullLiteral, err return NullLiteral, err
} }
@@ -221,11 +224,11 @@ func IsNot(a, b Expr) Expr {
func (op *IsNotOperator) Eval(env *environment.Environment) (types.Value, error) { func (op *IsNotOperator) Eval(env *environment.Environment) (types.Value, error) {
return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) { return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) {
ok, err := types.IsNotEqual(a, b) eq, err := a.EQ(b)
if err != nil { if err != nil {
return NullLiteral, err return NullLiteral, err
} }
if ok { if !eq {
return TrueLiteral, nil return TrueLiteral, nil
} }

View File

@@ -363,7 +363,7 @@ func (m *MinAggregator) Aggregate(env *environment.Environment) error {
} }
if m.Min.Type() == v.Type() || m.Min.Type().IsNumber() && v.Type().IsNumber() { if m.Min.Type() == v.Type() || m.Min.Type().IsNumber() && v.Type().IsNumber() {
ok, err := types.IsGreaterThan(m.Min, v) ok, err := m.Min.GT(v)
if err != nil { if err != nil {
return err return err
} }
@@ -467,7 +467,7 @@ func (m *MaxAggregator) Aggregate(env *environment.Environment) error {
} }
if m.Max.Type() == v.Type() || m.Max.Type().IsNumber() && v.Type().IsNumber() { if m.Max.Type() == v.Type() || m.Max.Type().IsNumber() && v.Type().IsNumber() {
ok, err := types.IsLesserThan(m.Max, v) ok, err := m.Max.LT(v)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -22,7 +22,7 @@ func (v LiteralValue) IsEqual(other Expr) bool {
if !ok { if !ok {
return false return false
} }
ok, err := types.IsEqual(v.Value, o.Value) ok, err := v.Value.EQ(o.Value)
return ok && err == nil return ok && err == nil
} }

View File

@@ -27,7 +27,7 @@ func ArrayContains(a types.Array, v types.Value) (bool, error) {
var found bool var found bool
err := a.Iterate(func(i int, vv types.Value) error { err := a.Iterate(func(i int, vv types.Value) error {
ok, err := types.IsEqual(vv, v) ok, err := vv.EQ(v)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -77,7 +77,7 @@ func diff(path Path, d1, d2 types.Object) ([]Op, error) {
} }
ops = append(ops, subOps...) ops = append(ops, subOps...)
default: default:
ok, err := types.IsEqual(v1, v2) ok, err := v1.EQ(v2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -145,7 +145,7 @@ func arrayDiff(path Path, a1, a2 types.Array) ([]Op, error) {
} }
ops = append(ops, subOps...) ops = append(ops, subOps...)
default: default:
ok, err := types.IsEqual(v1, v2) ok, err := v1.EQ(v2)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -56,7 +56,7 @@ func (op *GroupAggregateOperator) Iterate(in *environment.Environment, f func(ou
return ga.Aggregate(out) return ga.Aggregate(out)
} }
ok, err := types.IsEqual(lastGroup, group) ok, err := lastGroup.EQ(group)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -114,7 +114,7 @@ func (vb *sortableValueBuffer) Less(i, j int) (ok bool) {
if it == jt || (it.IsNumber() && jt.IsNumber()) { if it == jt || (it.IsNumber() && jt.IsNumber()) {
// TODO(asdine) make the types package work with static objects // TODO(asdine) make the types package work with static objects
// to avoid having to deal with errors? // to avoid having to deal with errors?
ok, _ = types.IsLesserThan(vb.Values[i], vb.Values[j]) ok, _ = vb.Values[i].LT(vb.Values[j])
return return
} }

View File

@@ -1,210 +0,0 @@
package types
import (
"fmt"
"math"
"time"
)
// Add u to v and return the result.
// Only numeric values and booleans can be added together.
func Add(v1, v2 Value) (res Value, err error) {
return calculateValues(v1, v2, '+')
}
// Sub calculates v - u and returns the result.
// Only numeric values and booleans can be calculated together.
func Sub(v1, v2 Value) (res Value, err error) {
return calculateValues(v1, v2, '-')
}
// Mul calculates v * u and returns the result.
// Only numeric values and booleans can be calculated together.
func Mul(v1, v2 Value) (res Value, err error) {
return calculateValues(v1, v2, '*')
}
// Div calculates v / u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
func Div(v1, v2 Value) (res Value, err error) {
return calculateValues(v1, v2, '/')
}
// Mod calculates v / u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
func Mod(v1, v2 Value) (res Value, err error) {
return calculateValues(v1, v2, '%')
}
// BitwiseAnd calculates v & u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
func BitwiseAnd(v1, v2 Value) (res Value, err error) {
return calculateValues(v1, v2, '&')
}
// BitwiseOr calculates v | u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
func BitwiseOr(v1, v2 Value) (res Value, err error) {
return calculateValues(v1, v2, '|')
}
// BitwiseXor calculates v ^ u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
func BitwiseXor(v1, v2 Value) (res Value, err error) {
return calculateValues(v1, v2, '^')
}
func calculateValues(a, b Value, operator byte) (res Value, err error) {
if a.Type() == TypeNull || b.Type() == TypeNull {
return NewNullValue(), nil
}
if a.Type().IsNumber() && b.Type().IsNumber() {
if a.Type() == TypeDouble || b.Type() == TypeDouble {
return calculateFloats(a, b, operator)
}
return calculateIntegers(a, b, operator)
}
return NewNullValue(), nil
}
func calculateIntegers(a, b Value, operator byte) (res Value, err error) {
var xa, xb int64
ia := convertNumberToInteger(a)
xa = As[int64](ia)
ib := convertNumberToInteger(b)
xb = As[int64](ib)
var xr int64
switch operator {
case '-':
xb = -xb
fallthrough
case '+':
xr = xa + xb
// if there is an integer overflow
// convert to float
if (xr > xa) != (xb > 0) {
return NewDoubleValue(float64(xa) + float64(xb)), nil
}
return NewIntegerValue(xr), nil
case '*':
if xa == 0 || xb == 0 {
return NewIntegerValue(0), nil
}
xr = xa * xb
// if there is no integer overflow
// return an int, otherwise
// convert to float
if (xr < 0) == ((xa < 0) != (xb < 0)) {
if xr/xb == xa {
return NewIntegerValue(xr), nil
}
}
return NewDoubleValue(float64(xa) * float64(xb)), nil
case '/':
if xb == 0 {
return NewNullValue(), nil
}
return NewIntegerValue(xa / xb), nil
case '%':
if xb == 0 {
return NewNullValue(), nil
}
return NewIntegerValue(xa % xb), nil
case '&':
return NewIntegerValue(xa & xb), nil
case '|':
return NewIntegerValue(xa | xb), nil
case '^':
return NewIntegerValue(xa ^ xb), nil
default:
panic(fmt.Sprintf("unknown operator %c", operator))
}
}
func calculateFloats(a, b Value, operator byte) (res Value, err error) {
var xa, xb float64
fa := convertNumberToDouble(a)
xa = As[float64](fa)
fb := convertNumberToDouble(b)
xb = As[float64](fb)
switch operator {
case '+':
return NewDoubleValue(xa + xb), nil
case '-':
return NewDoubleValue(xa - xb), nil
case '*':
return NewDoubleValue(xa * xb), nil
case '/':
if xb == 0 {
return NewNullValue(), nil
}
return NewDoubleValue(xa / xb), nil
case '%':
mod := math.Mod(xa, xb)
if math.IsNaN(mod) {
return NewNullValue(), nil
}
return NewDoubleValue(mod), nil
case '&':
ia, ib := int64(xa), int64(xb)
return NewIntegerValue(ia & ib), nil
case '|':
ia, ib := int64(xa), int64(xb)
return NewIntegerValue(ia | ib), nil
case '^':
ia, ib := int64(xa), int64(xb)
return NewIntegerValue(ia ^ ib), nil
default:
panic(fmt.Sprintf("unknown operator %c", operator))
}
}
func convertNumberToInteger(v Value) Value {
switch v.Type() {
case TypeInteger:
return v
default:
return NewIntegerValue(int64(As[float64](v)))
}
}
func convertNumberToDouble(v Value) Value {
switch v.Type() {
case TypeDouble:
return v
default:
return NewDoubleValue(float64(As[int64](v)))
}
}
func convertToTime(v Value) (time.Time, error) {
switch v.Type() {
case TypeTimestamp:
return As[time.Time](v), nil
case TypeText:
return ParseTimestamp(As[string](v))
default:
panic(fmt.Sprintf("cannot convert %v to time", v.Type()))
}
}

View File

@@ -35,7 +35,7 @@ func TestValueAdd(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
res, err := types.Add(test.v, test.u) res, err := test.v.Add(test.u)
if test.fails { if test.fails {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@@ -71,7 +71,7 @@ func TestValueSub(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
res, err := types.Sub(test.v, test.u) res, err := test.v.Sub(test.u)
if test.fails { if test.fails {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@@ -105,7 +105,7 @@ func TestValueMult(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
res, err := types.Mul(test.v, test.u) res, err := test.v.Mul(test.u)
if test.fails { if test.fails {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@@ -140,7 +140,7 @@ func TestValueDiv(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
res, err := types.Div(test.v, test.u) res, err := test.v.Div(test.u)
if test.fails { if test.fails {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@@ -177,7 +177,7 @@ func TestValueMod(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
res, err := types.Mod(test.v, test.u) res, err := test.v.Mod(test.u)
if test.fails { if test.fails {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@@ -211,7 +211,7 @@ func TestValueBitwiseAnd(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
res, err := types.BitwiseAnd(test.v, test.u) res, err := test.v.BitwiseAnd(test.u)
if test.fails { if test.fails {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@@ -244,7 +244,7 @@ func TestValueBitwiseOr(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
res, err := types.BitwiseOr(test.v, test.u) res, err := test.v.BitwiseOr(test.u)
if test.fails { if test.fails {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@@ -276,7 +276,7 @@ func TestValueBitwiseXor(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
res, err := types.BitwiseXor(test.v, test.u) res, err := test.v.BitwiseXor(test.u)
if test.fails { if test.fails {
assert.Error(t, err) assert.Error(t, err)
} else { } else {

View File

@@ -1,10 +1,8 @@
package types package types
import ( import (
"bytes"
"sort" "sort"
"strings" "strings"
"time"
) )
type operator uint8 type operator uint8
@@ -34,224 +32,6 @@ func (op operator) String() string {
return "" return ""
} }
// IsEqual returns true if v is equal to the given value.
func IsEqual(v, other Value) (bool, error) {
return compare(operatorEq, v, other)
}
// IsNotEqual returns true if v is not equal to the given value.
func IsNotEqual(v, other Value) (bool, error) {
ok, err := IsEqual(v, other)
if err != nil {
return ok, err
}
return !ok, nil
}
// IsGreaterThan returns true if v is greather than the given value.
func IsGreaterThan(v, other Value) (bool, error) {
return compare(operatorGt, v, other)
}
// IsGreaterThanOrEqual returns true if v is greather than or equal to the given value.
func IsGreaterThanOrEqual(v, other Value) (bool, error) {
return compare(operatorGte, v, other)
}
// IsLesserThan returns true if v is lesser than the given value.
func IsLesserThan(v, other Value) (bool, error) {
return compare(operatorLt, v, other)
}
// IsLesserThanOrEqual returns true if v is lesser than or equal to the given value.
func IsLesserThanOrEqual(v, other Value) (bool, error) {
return compare(operatorLte, v, other)
}
func compare(op operator, l, r Value) (bool, error) {
switch {
// deal with nil
case l.Type() == TypeNull || r.Type() == TypeNull:
return compareWithNull(op, l, r), nil
// compare booleans together
case l.Type() == TypeBoolean && r.Type() == TypeBoolean:
return compareBooleans(op, As[bool](l), As[bool](r)), nil
// compare texts together
case l.Type() == TypeText && r.Type() == TypeText:
return compareTexts(op, As[string](l), As[string](r)), nil
// compare blobs together
case r.Type() == TypeBlob && l.Type() == TypeBlob:
return compareBlobs(op, As[[]byte](l), As[[]byte](r)), nil
// compare integers together
case l.Type() == TypeInteger && r.Type() == TypeInteger:
return compareIntegers(op, As[int64](l), As[int64](r)), nil
// compare numbers together
case l.Type().IsNumber() && r.Type().IsNumber():
return compareNumbers(op, l, r), nil
// compare timestamps together
case l.Type() == TypeTimestamp && r.Type() == TypeTimestamp:
return compareTimes(op, As[time.Time](l), As[time.Time](r)), nil
// compare arrays together
case l.Type() == TypeArray && r.Type() == TypeArray:
return compareArrays(op, As[Array](l), As[Array](r))
// compare objects together
case l.Type() == TypeObject && r.Type() == TypeObject:
return compareobjects(op, As[Object](l), As[Object](r))
}
// compare compatible timestamps
if l.Type() == TypeTimestamp && r.Type().IsTimestampCompatible() {
return compareTimestamps(op, l, r)
} else if r.Type() == TypeTimestamp && l.Type().IsTimestampCompatible() {
return compareTimestamps(op, l, r)
}
return false, nil
}
func compareWithNull(op operator, l, r Value) bool {
switch op {
case operatorEq, operatorGte, operatorLte:
return l.Type() == r.Type()
case operatorGt, operatorLt:
return false
}
return false
}
func compareBooleans(op operator, a, b bool) bool {
switch op {
case operatorEq:
return a == b
case operatorGt:
return a && !b
case operatorGte:
return a == b || a
case operatorLt:
return !a && b
case operatorLte:
return a == b || !a
}
return false
}
func compareTexts(op operator, l, r string) bool {
switch op {
case operatorEq:
return l == r
case operatorGt:
return strings.Compare(l, r) > 0
case operatorGte:
return strings.Compare(l, r) >= 0
case operatorLt:
return strings.Compare(l, r) < 0
case operatorLte:
return strings.Compare(l, r) <= 0
}
return false
}
func compareBlobs(op operator, l, r []byte) bool {
switch op {
case operatorEq:
return bytes.Equal(l, r)
case operatorGt:
return bytes.Compare(l, r) > 0
case operatorGte:
return bytes.Compare(l, r) >= 0
case operatorLt:
return bytes.Compare(l, r) < 0
case operatorLte:
return bytes.Compare(l, r) <= 0
}
return false
}
func compareIntegers(op operator, l, r int64) bool {
switch op {
case operatorEq:
return l == r
case operatorGt:
return l > r
case operatorGte:
return l >= r
case operatorLt:
return l < r
case operatorLte:
return l <= r
}
return false
}
func compareNumbers(op operator, l, r Value) bool {
l = convertNumberToDouble(l)
r = convertNumberToDouble(r)
af := As[float64](l)
bf := As[float64](r)
var ok bool
switch op {
case operatorEq:
ok = af == bf
case operatorGt:
ok = af > bf
case operatorGte:
ok = af >= bf
case operatorLt:
ok = af < bf
case operatorLte:
ok = af <= bf
}
return ok
}
func compareTimes(op operator, l, r time.Time) bool {
switch op {
case operatorEq:
return l.Equal(r)
case operatorGt:
return l.After(r)
case operatorGte:
return l.After(r) || l.Equal(r)
case operatorLt:
return l.Before(r)
case operatorLte:
return l.Before(r) || l.Equal(r)
}
return false
}
func compareTimestamps(op operator, l, r Value) (bool, error) {
t1, err := convertToTime(l)
if err != nil {
return false, err
}
t2, err := convertToTime(r)
if err != nil {
return false, err
}
return compareTimes(op, t1, t2), nil
}
func compareArrays(op operator, l Array, r Array) (bool, error) { func compareArrays(op operator, l Array, r Array) (bool, error) {
var i, j int var i, j int
@@ -267,16 +47,24 @@ func compareArrays(op operator, l Array, r Array) (bool, error) {
if lerr != nil || rerr != nil { if lerr != nil || rerr != nil {
break break
} }
if lv.Type() == rv.Type() || (lv.Type().IsNumber() && rv.Type().IsNumber()) { if lv.Type().IsComparableWith(rv.Type()) {
isEq, err := compare(operatorEq, lv, rv) isEq, err := lv.EQ(rv)
if err != nil { if err != nil {
return false, err return false, err
} }
if !isEq && op != operatorEq {
return compare(op, lv, rv)
}
if !isEq { if !isEq {
return false, nil switch op {
case operatorEq:
return false, nil
case operatorGt:
return lv.GT(rv)
case operatorGte:
return lv.GTE(rv)
case operatorLt:
return lv.LT(rv)
case operatorLte:
return lv.LTE(rv)
}
} }
} else { } else {
switch op { switch op {
@@ -315,7 +103,7 @@ func compareArrays(op operator, l Array, r Array) (bool, error) {
} }
} }
func compareobjects(op operator, l, r Object) (bool, error) { func compareObjects(op operator, l, r Object) (bool, error) {
lf, err := Fields(l) lf, err := Fields(l)
if err != nil { if err != nil {
return false, err return false, err
@@ -384,16 +172,24 @@ func compareobjects(op operator, l, r Object) (bool, error) {
if lerr != nil || rerr != nil { if lerr != nil || rerr != nil {
break break
} }
if lv.Type() == rv.Type() || (lv.Type().IsNumber() && rv.Type().IsNumber()) { if lv.Type().IsComparableWith(rv.Type()) {
isEq, err := compare(operatorEq, lv, rv) isEq, err := lv.EQ(rv)
if err != nil { if err != nil {
return false, err return false, err
} }
if !isEq && op != operatorEq {
return compare(op, lv, rv)
}
if !isEq { if !isEq {
return false, nil switch op {
case operatorEq:
return false, nil
case operatorGt:
return lv.GT(rv)
case operatorGte:
return lv.GTE(rv)
case operatorLt:
return lv.LT(rv)
case operatorLte:
return lv.LTE(rv)
}
} }
} else { } else {
switch op { switch op {

View File

@@ -269,17 +269,18 @@ func TestCompare(t *testing.T) {
switch test.op { switch test.op {
case "=": case "=":
ok, err = types.IsEqual(a, b) ok, err = a.EQ(b)
case "!=": case "!=":
ok, err = types.IsNotEqual(a, b) ok, err = a.EQ(b)
ok = !ok
case ">": case ">":
ok, err = types.IsGreaterThan(a, b) ok, err = a.GT(b)
case ">=": case ">=":
ok, err = types.IsGreaterThanOrEqual(a, b) ok, err = a.GTE(b)
case "<": case "<":
ok, err = types.IsLesserThan(a, b) ok, err = a.LT(b)
case "<=": case "<=":
ok, err = types.IsLesserThanOrEqual(a, b) ok, err = a.LTE(b)
} }
assert.NoError(t, err) assert.NoError(t, err)
require.Equal(t, test.ok, ok) require.Equal(t, test.ok, ok)
@@ -322,17 +323,18 @@ func TestCompareValues(t *testing.T) {
switch test.op { switch test.op {
case "=": case "=":
ok, err = types.IsEqual(a, b) ok, err = a.EQ(b)
case "!=": case "!=":
ok, err = types.IsNotEqual(a, b) ok, err = a.EQ(b)
ok = !ok
case ">": case ">":
ok, err = types.IsGreaterThan(a, b) ok, err = a.GT(b)
case ">=": case ">=":
ok, err = types.IsGreaterThanOrEqual(a, b) ok, err = a.GTE(b)
case "<": case "<":
ok, err = types.IsLesserThan(a, b) ok, err = a.LT(b)
case "<=": case "<=":
ok, err = types.IsLesserThanOrEqual(a, b) ok, err = a.LTE(b)
} }
assert.NoError(t, err) assert.NoError(t, err)
require.Equal(t, test.ok, ok) require.Equal(t, test.ok, ok)

View File

@@ -68,6 +68,22 @@ func (t ValueType) IsTimestampCompatible() bool {
return t == TypeTimestamp || t == TypeText return t == TypeTimestamp || t == TypeText
} }
func (t ValueType) IsComparableWith(other ValueType) bool {
if t == other {
return true
}
if t.IsNumber() && other.IsNumber() {
return true
}
if t.IsTimestampCompatible() && other.IsTimestampCompatible() {
return true
}
return false
}
// IsAny returns whether this is type is Any or a real type // IsAny returns whether this is type is Any or a real type
func (t ValueType) IsAny() bool { func (t ValueType) IsAny() bool {
return t == TypeAny return t == TypeAny
@@ -80,6 +96,41 @@ type Value interface {
String() string String() string
MarshalJSON() ([]byte, error) MarshalJSON() ([]byte, error)
MarshalText() ([]byte, error) MarshalText() ([]byte, error)
EQ(other Value) (bool, error)
GT(other Value) (bool, error)
GTE(other Value) (bool, error)
LT(other Value) (bool, error)
LTE(other Value) (bool, error)
Between(a, b Value) (bool, error)
// Add u to v and return the result.
// Only numeric values and booleans can be added together.
Add(other Value) (Value, error)
// Sub calculates v - u and returns the result.
// Only numeric values and booleans can be calculated together.
Sub(other Value) (Value, error)
// Mul calculates v * u and returns the result.
// Only numeric values and booleans can be calculated together.
Mul(other Value) (Value, error)
// Div calculates v / u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
Div(other Value) (Value, error)
// Mod calculates v / u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
Mod(other Value) (Value, error)
// BitwiseAnd calculates v & u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
BitwiseAnd(other Value) (Value, error)
// BitwiseOr calculates v | u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
BitwiseOr(other Value) (Value, error)
// BitwiseXor calculates v ^ u and returns the result.
// Only numeric values and booleans can be calculated together.
// If both v and u are integers, the result will be an integer.
BitwiseXor(other Value) (Value, error)
} }
// A Object represents a group of key value pairs. // A Object represents a group of key value pairs.

File diff suppressed because it is too large Load Diff