Files
kubevpn/vendor/github.com/expr-lang/expr/checker/checker.go
2025-04-19 10:06:56 +08:00

1286 lines
32 KiB
Go

package checker
import (
"fmt"
"reflect"
"regexp"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/builtin"
. "github.com/expr-lang/expr/checker/nature"
"github.com/expr-lang/expr/conf"
"github.com/expr-lang/expr/file"
"github.com/expr-lang/expr/parser"
)
// runVisitors walks the tree with the visitors registered in config.
//
// runRepeatable selects which visitors are run: when true, only visitors that
// implement Reset/ShouldRepeat are walked (and the whole pass is repeated for
// as long as any of them asks for another round); when false, only the
// remaining, single-pass visitors are walked.
func runVisitors(tree *parser.Tree, config *conf.Config, runRepeatable bool) {
	// repeater is the optional contract of visitors that need several passes.
	type repeater interface {
		Reset()
		ShouldRepeat() bool
	}

	for again := true; again; {
		again = false
		for _, visitor := range config.Visitors {
			// Visitors may rely on type information stored in the tree, so
			// the tree is re-checked before each of them. The result and the
			// error are intentionally ignored here.
			_, _ = Check(tree, config)

			rep, isRepeater := visitor.(repeater)
			switch {
			case isRepeater && runRepeatable:
				rep.Reset()
				ast.Walk(&tree.Node, visitor)
				// Once another round is already requested, ShouldRepeat is
				// not consulted again during this round.
				if !again {
					again = rep.ShouldRepeat()
				}
			case !isRepeater && !runRepeatable:
				ast.Walk(&tree.Node, visitor)
			}
		}
	}
}
// ParseCheck parses the input expression, applies every visitor (patcher)
// registered in config, and finally type-checks the resulting tree. The tree
// is returned alongside any parse or check error.
func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) {
	tree, err := parser.ParseWithConfig(input, config)
	if err != nil {
		return tree, err
	}

	if len(config.Visitors) != 0 {
		// First pass: patchers that do not support being run repeatedly.
		runVisitors(tree, config, false)
		// Second pass: patchers that may require multiple rounds.
		runVisitors(tree, config, true)
	}

	if _, checkErr := Check(tree, config); checkErr != nil {
		return tree, checkErr
	}
	return tree, nil
}
// Check type-checks the expression tree and returns the type of the whole
// expression together with the first error found, if any. A nil config is
// replaced by a default one created with conf.New(nil).
func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) {
	if config == nil {
		config = conf.New(nil)
	}

	c := &checker{config: config}
	nature := c.visit(tree.Node)

	// For compatibility with previous versions an unknown nature is reported
	// as the any type.
	resultType := nature.Type
	if resultType == nil {
		resultType = anyType
	}

	if c.err != nil {
		return resultType, c.err.Bind(tree.Source)
	}

	expect := c.config.Expect
	if expect == reflect.Invalid {
		// No expectation on the result kind was configured.
		return resultType, nil
	}
	if c.config.ExpectAny && isUnknown(nature) {
		return resultType, nil
	}

	switch expect {
	case reflect.Int, reflect.Int64, reflect.Float64:
		// Numeric expectations accept any number nature.
		if !isNumber(nature) {
			return nil, fmt.Errorf("expected %v, but got %v", expect, nature)
		}
	default:
		if nature.Kind() != expect {
			return nil, fmt.Errorf("expected %v, but got %s", expect, nature)
		}
	}
	return resultType, nil
}
// checker holds the state of a single type-checking run over an AST.
type checker struct {
// config supplies the environment, functions, builtins and the Strict /
// Expect settings consulted while checking.
config *conf.Config
// predicateScopes is a stack: begin pushes a scope, end pops the last one.
predicateScopes []predicateScope
// varScopes is a stack of declared variables; VariableDeclaratorNode pushes
// an entry while visiting the expression and pops it afterwards.
varScopes []varScope
// err is the first error recorded by the error method; later errors do not
// overwrite it.
err *file.Error
}
// predicateScope describes the context in which a predicate body is checked.
type predicateScope struct {
// collection is the nature of the collection the predicate iterates over;
// PointerNode resolves an unnamed pointer to its element nature.
collection Nature
// vars maps named pointers (for example "index" or "acc") to their nature.
vars map[string]Nature
}
// varScope is one declared variable: its name and the nature of its value.
type varScope struct {
name string
nature Nature
}
// info carries extra type information.
// NOTE(review): this type is not referenced anywhere else in this file;
// confirm whether other files of the package still use it.
type info struct {
method bool
fn *builtin.Function
// elem is element type of array or map.
// Arrays created with type []any, but
// we would like to detect expressions
// like `42 in ["a"]` as invalid.
elem reflect.Type
}
// visit dispatches node to the checker method matching its concrete type,
// stores the resulting nature on the node via SetNature and returns it.
// It panics when the node type is not one of the handled AST node types.
func (v *checker) visit(node ast.Node) Nature {
	var result Nature

	switch typed := node.(type) {
	// Literals and constants.
	case *ast.NilNode:
		result = v.NilNode(typed)
	case *ast.IntegerNode:
		result = v.IntegerNode(typed)
	case *ast.FloatNode:
		result = v.FloatNode(typed)
	case *ast.BoolNode:
		result = v.BoolNode(typed)
	case *ast.StringNode:
		result = v.StringNode(typed)
	case *ast.ConstantNode:
		result = v.ConstantNode(typed)

	// Names and variables.
	case *ast.IdentifierNode:
		result = v.IdentifierNode(typed)
	case *ast.PointerNode:
		result = v.PointerNode(typed)
	case *ast.VariableDeclaratorNode:
		result = v.VariableDeclaratorNode(typed)

	// Operators and access.
	case *ast.UnaryNode:
		result = v.UnaryNode(typed)
	case *ast.BinaryNode:
		result = v.BinaryNode(typed)
	case *ast.ChainNode:
		result = v.ChainNode(typed)
	case *ast.MemberNode:
		result = v.MemberNode(typed)
	case *ast.SliceNode:
		result = v.SliceNode(typed)

	// Calls and predicates.
	case *ast.CallNode:
		result = v.CallNode(typed)
	case *ast.BuiltinNode:
		result = v.BuiltinNode(typed)
	case *ast.PredicateNode:
		result = v.PredicateNode(typed)

	// Compound expressions.
	case *ast.SequenceNode:
		result = v.SequenceNode(typed)
	case *ast.ConditionalNode:
		result = v.ConditionalNode(typed)
	case *ast.ArrayNode:
		result = v.ArrayNode(typed)
	case *ast.MapNode:
		result = v.MapNode(typed)
	case *ast.PairNode:
		result = v.PairNode(typed)

	default:
		panic(fmt.Sprintf("undefined node type (%T)", node))
	}

	node.SetNature(result)
	return result
}
// error records a formatted error located at node, unless an error has
// already been recorded (only the first error is kept). It always returns the
// unknown nature so callers can write `return v.error(...)`.
func (v *checker) error(node ast.Node, format string, args ...any) Nature {
	if v.err != nil {
		return unknown
	}
	v.err = &file.Error{
		Location: node.Location(),
		Message:  fmt.Sprintf(format, args...),
	}
	return unknown
}
// NilNode returns the nil nature for a nil literal.
func (v *checker) NilNode(_ *ast.NilNode) Nature {
	return nilNature
}
// IdentifierNode resolves an identifier. Declared variables take precedence;
// the special name "$env" is unknown; everything else is resolved by ident
// with builtins and functions enabled.
func (v *checker) IdentifierNode(node *ast.IdentifierNode) Nature {
	name := node.Value
	if declared, found := v.lookupVariable(name); found {
		return declared.nature
	}
	if name == "$env" {
		return unknown
	}
	return v.ident(node, name, v.config.Strict, true)
}
// ident returns the nature of name, looked up first in the environment and
// then, when builtins is true, among the configured functions and builtins.
// When nothing matches it reports "unknown name" only if both the strict
// argument and the config's Strict flag are set; otherwise it returns unknown.
func (v *checker) ident(node ast.Node, name string, strict, builtins bool) Nature {
	if envNature, found := v.config.Env.Get(name); found {
		return envNature
	}

	if builtins {
		if custom, found := v.config.Functions[name]; found {
			return Nature{Type: custom.Type(), Func: custom}
		}
		if builtinFn, found := v.config.Builtins[name]; found {
			return Nature{Type: builtinFn.Type(), Func: builtinFn}
		}
	}

	if !v.config.Strict || !strict {
		return unknown
	}
	return v.error(node, "unknown name %v", name)
}
// IntegerNode returns the integer nature for an integer literal.
func (v *checker) IntegerNode(_ *ast.IntegerNode) Nature {
	return integerNature
}
// FloatNode returns the float nature for a float literal.
func (v *checker) FloatNode(_ *ast.FloatNode) Nature {
	return floatNature
}
// BoolNode returns the bool nature for a boolean literal.
func (v *checker) BoolNode(_ *ast.BoolNode) Nature {
	return boolNature
}
// StringNode returns the string nature for a string literal.
func (v *checker) StringNode(_ *ast.StringNode) Nature {
	return stringNature
}
// ConstantNode returns a nature whose type is the dynamic type of the
// constant's value, as reported by reflect.TypeOf.
func (v *checker) ConstantNode(node *ast.ConstantNode) Nature {
	valueType := reflect.TypeOf(node.Value)
	return Nature{Type: valueType}
}
// UnaryNode checks a unary operation on the dereferenced operand.
// "!"/"not" yield bool for bool or unknown operands; "+"/"-" keep the nature
// of a number operand and yield unknown for an unknown operand. Any other
// operand is a type mismatch, and any other operator is reported as unknown.
func (v *checker) UnaryNode(node *ast.UnaryNode) Nature {
	operand := v.visit(node.Node).Deref()

	switch op := node.Operator; op {
	case "!", "not":
		if isBool(operand) || isUnknown(operand) {
			return boolNature
		}
	case "+", "-":
		switch {
		case isNumber(operand):
			return operand
		case isUnknown(operand):
			return unknown
		}
	default:
		return v.error(node, "unknown operator (%v)", op)
	}

	return v.error(node, `invalid operation: %v (mismatched type %s)`, node.Operator, operand)
}
// BinaryNode checks a binary operation and returns the nature of its result.
//
// Both operands are visited (left first) and dereferenced. For every operator
// the exact operand combinations are tried first; the trailing `or(...)`
// check in most cases is a helper defined outside this file (presumably it
// accepts combinations where one side is unknown — confirm against its
// definition). If no rule of a known operator matches, an
// "invalid operation ... mismatched types" error is recorded; an unrecognised
// operator records "unknown operator".
func (v *checker) BinaryNode(node *ast.BinaryNode) Nature {
	l := v.visit(node.Left)
	r := v.visit(node.Right)

	l = l.Deref()
	r = r.Deref()

	switch node.Operator {
	case "==", "!=":
		if isComparable(l, r) {
			return boolNature
		}

	case "or", "||", "and", "&&":
		if isBool(l) && isBool(r) {
			return boolNature
		}
		if or(l, r, isBool) {
			return boolNature
		}

	case "<", ">", ">=", "<=":
		if isNumber(l) && isNumber(r) {
			return boolNature
		}
		if isString(l) && isString(r) {
			return boolNature
		}
		if isTime(l) && isTime(r) {
			return boolNature
		}
		if isDuration(l) && isDuration(r) {
			return boolNature
		}
		if or(l, r, isNumber, isString, isTime, isDuration) {
			return boolNature
		}

	case "-":
		if isNumber(l) && isNumber(r) {
			return combined(l, r)
		}
		if isTime(l) && isTime(r) {
			return durationNature
		}
		if isTime(l) && isDuration(r) {
			return timeNature
		}
		if isDuration(l) && isDuration(r) {
			return durationNature
		}
		if or(l, r, isNumber, isTime, isDuration) {
			return unknown
		}

	case "*":
		if isNumber(l) && isNumber(r) {
			return combined(l, r)
		}
		if isNumber(l) && isDuration(r) {
			return durationNature
		}
		if isDuration(l) && isNumber(r) {
			return durationNature
		}
		if isDuration(l) && isDuration(r) {
			return durationNature
		}
		if or(l, r, isNumber, isDuration) {
			return unknown
		}

	case "/":
		if isNumber(l) && isNumber(r) {
			return floatNature
		}
		if or(l, r, isNumber) {
			return floatNature
		}

	case "**", "^":
		if isNumber(l) && isNumber(r) {
			return floatNature
		}
		if or(l, r, isNumber) {
			return floatNature
		}

	case "%":
		if isInteger(l) && isInteger(r) {
			return integerNature
		}
		if or(l, r, isInteger) {
			return integerNature
		}

	case "+":
		if isNumber(l) && isNumber(r) {
			return combined(l, r)
		}
		if isString(l) && isString(r) {
			return stringNature
		}
		if isTime(l) && isDuration(r) {
			return timeNature
		}
		if isDuration(l) && isTime(r) {
			return timeNature
		}
		if isDuration(l) && isDuration(r) {
			return durationNature
		}
		if or(l, r, isNumber, isString, isTime, isDuration) {
			return unknown
		}

	case "in":
		// Field lookup on a struct: the left side must be a string (or unknown).
		if (isString(l) || isUnknown(l)) && isStruct(r) {
			return boolNature
		}
		if isMap(r) {
			if !isUnknown(l) && !l.AssignableTo(r.Key()) {
				return v.error(node, "cannot use %v as type %v in map key", l, r.Key())
			}
			return boolNature
		}
		if isArray(r) {
			if !isComparable(l, r.Elem()) {
				return v.error(node, "cannot use %v as type %v in array", l, r.Elem())
			}
			return boolNature
		}
		if isUnknown(l) && anyOf(r, isString, isArray, isMap) {
			return boolNature
		}
		if isUnknown(r) {
			return boolNature
		}

	case "matches":
		// A literal pattern is compiled here so that an invalid regular
		// expression is reported at check time.
		if s, ok := node.Right.(*ast.StringNode); ok {
			_, err := regexp.Compile(s.Value)
			if err != nil {
				// The regexp error text embeds the user's pattern and may
				// contain '%'. It must therefore be passed as an argument,
				// never as the format string itself.
				return v.error(node, "%s", err.Error())
			}
		}
		if isString(l) && isString(r) {
			return boolNature
		}
		if or(l, r, isString) {
			return boolNature
		}

	case "contains", "startsWith", "endsWith":
		if isString(l) && isString(r) {
			return boolNature
		}
		if or(l, r, isString) {
			return boolNature
		}

	case "..":
		if isInteger(l) && isInteger(r) {
			return arrayOf(integerNature)
		}
		if or(l, r, isInteger) {
			return arrayOf(integerNature)
		}

	case "??":
		// Nil coalescing: prefer whichever side is not nil; when neither is
		// nil the left nature is used if the right one is assignable to it.
		if isNil(l) && !isNil(r) {
			return r
		}
		if !isNil(l) && isNil(r) {
			return l
		}
		if isNil(l) && isNil(r) {
			return nilNature
		}
		if r.AssignableTo(l) {
			return l
		}
		return unknown

	default:
		return v.error(node, "unknown operator (%v)", node.Operator)
	}

	return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)
}
// ChainNode is transparent for type checking: its nature is the nature of
// the node it wraps.
func (v *checker) ChainNode(node *ast.ChainNode) Nature {
	inner := v.visit(node.Node)
	return inner
}
// MemberNode checks a member access (`base.prop` or `base[prop]`).
//
// Access on `$env` is resolved through ident without builtins or functions.
// Otherwise the base and the property are visited; methods of the
// non-dereferenced base type are tried first, then maps, arrays/slices and
// structs are handled on the dereferenced base. Anything else is reported as
// a missing method/field or as an undefined index expression.
func (v *checker) MemberNode(node *ast.MemberNode) Nature {
	// A string-literal property is what field and method lookups operate on.
	propName, hasName := node.Property.(*ast.StringNode)

	if ident, isIdent := node.Node.(*ast.IdentifierNode); isIdent && ident.Value == "$env" {
		if !hasName {
			return unknown
		}
		// With the optional flag the user handles a missing field explicitly,
		// so a missing name must not be reported even in strict mode.
		strict := v.config.Strict && !node.Optional
		return v.ident(node, propName.Value, strict, false /* no builtins and no functions */)
	}

	base := v.visit(node.Node)
	prop := v.visit(node.Property)

	if isUnknown(base) {
		return unknown
	}

	if hasName {
		if isNil(base) {
			return v.error(node, "type nil has no field %v", propName.Value)
		}
		// Methods are looked up on the base type as is, before dereferencing.
		if method, found := base.MethodByName(propName.Value); found {
			return method
		}
	}

	base = base.Deref()

	switch base.Kind() {
	case reflect.Map:
		if !prop.AssignableTo(base.Key()) && !isUnknown(prop) {
			return v.error(node.Property, "cannot use %v to get an element from %v", prop, base)
		}
		if hasName {
			if field, known := base.Fields[propName.Value]; known {
				return field
			}
			if base.Strict {
				return v.error(node.Property, "unknown field %v", propName.Value)
			}
		}
		return base.Elem()

	case reflect.Array, reflect.Slice:
		if !isInteger(prop) && !isUnknown(prop) {
			return v.error(node.Property, "array elements can only be selected using an integer (got %v)", prop)
		}
		return base.Elem()

	case reflect.Struct:
		if hasName {
			if field, known := base.FieldByName(propName.Value); known {
				return Nature{Type: field.Type}
			}
			// A missing struct field is reported by the shared code below.
		}
	}

	// Not found.
	if !hasName {
		return v.error(node, "type %v[%v] is undefined", base, prop)
	}
	if node.Method {
		return v.error(node, "type %v has no method %v", base, propName.Value)
	}
	return v.error(node, "type %v has no field %v", base, propName.Value)
}
// SliceNode checks a slice expression `x[from:to]`. The sliced value must be
// a string, array or slice (or unknown), and each bound that is present must
// be an integer or unknown. The result has the nature of the sliced value.
func (v *checker) SliceNode(node *ast.SliceNode) Nature {
	target := v.visit(node.Node)
	if isUnknown(target) {
		return unknown
	}

	if kind := target.Kind(); kind != reflect.String && kind != reflect.Array && kind != reflect.Slice {
		return v.error(node, "cannot slice %s", target)
	}

	// Bounds are optional; they are checked in source order, lower one first.
	for _, bound := range [2]ast.Node{node.From, node.To} {
		if bound == nil {
			continue
		}
		if idx := v.visit(bound); !isInteger(idx) && !isUnknown(idx) {
			return v.error(bound, "non-integer slice index %v", idx)
		}
	}
	return target
}
// CallNode checks a function call and returns the nature of its result.
//
// The call is always checked through functionReturnType. If a type other
// than anyType has already been set on the node (for example by a patcher),
// the node's own nature wins over the computed one. anyType is deliberately
// not treated as "already set": on error a call such as `errCall().Method()`
// ends up typed as any, and once a patcher fixes `errCall()` a later checker
// pass must be able to replace that any with the real return type.
func (v *checker) CallNode(node *ast.CallNode) Nature {
	returned := v.functionReturnType(node)

	if preset := node.Type(); preset == nil || preset == anyType {
		return returned
	}
	return node.Nature()
}
// functionReturnType visits the callee of a call and returns the nature the
// call evaluates to. Callees that carry a builtin.Function are delegated to
// checkFunction; other callees must be of func kind and have their arguments
// verified by checkArguments. Unknown callees yield unknown.
func (v *checker) functionReturnType(node *ast.CallNode) Nature {
	callee := v.visit(node.Callee)
	if callee.Func != nil {
		return v.checkFunction(callee.Func, node, node.Arguments)
	}

	// Derive a readable name for error messages.
	fnName := "function"
	switch target := node.Callee.(type) {
	case *ast.IdentifierNode:
		fnName = target.Value
	case *ast.MemberNode:
		if prop, isString := target.Property.(*ast.StringNode); isString {
			fnName = prop.Value
		}
	}

	switch {
	case isUnknown(callee):
		return unknown
	case isNil(callee):
		return v.error(node, "%v is nil; cannot call nil as function", fnName)
	case callee.Kind() != reflect.Func:
		return v.error(node, "%s is not callable", callee)
	}

	result, argErr := v.checkArguments(fnName, callee, node.Arguments, node)
	if argErr != nil {
		// Only the first error of a checking run is kept.
		if v.err == nil {
			v.err = argErr
		}
		return unknown
	}
	return result
}
// BuiltinNode checks a call to a builtin and returns the nature of its result.
//
// Builtins that take a predicate (all, none, any, one, filter, map, count,
// sum, find, findLast, findIndex, findLastIndex, groupBy, sortBy, reduce) are
// handled by the switch below: the first argument is visited and must be an
// array or unknown, then the predicate is visited inside a predicate scope
// opened with begin and closed with end. Any other name is looked up in
// builtin.Index and checked with checkFunction, except "get", which is
// checked by checkBuiltinGet. Unknown names are reported as errors.
//
// The argument slice is indexed without length checks.
// NOTE(review): presumably the parser guarantees the argument counts — confirm.
func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
switch node.Name {
case "all", "none", "any", "one":
// Result is bool; the predicate must return bool or unknown.
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
v.begin(collection)
predicate := v.visit(node.Arguments[1])
v.end()
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) {
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String())
}
return boolNature
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")
case "filter":
// Result is an array of the collection's elements (a plain array when
// the collection is unknown).
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
v.begin(collection)
predicate := v.visit(node.Arguments[1])
v.end()
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) {
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String())
}
if isUnknown(collection) {
return arrayNature
}
return arrayOf(collection.Elem())
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")
case "map":
// The predicate scope additionally exposes the "index" pointer; the
// result is an array of whatever the predicate body evaluates to.
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
v.begin(collection, scopeVar{"index", integerNature})
predicate := v.visit(node.Arguments[1])
v.end()
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
return arrayOf(*predicate.PredicateOut)
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")
case "count":
// The predicate is optional: with a single argument the result is an
// integer without further checks.
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
if len(node.Arguments) == 1 {
return integerNature
}
v.begin(collection)
predicate := v.visit(node.Arguments[1])
v.end()
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) {
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String())
}
return integerNature
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")
case "sum":
// With a predicate the result is the predicate's output; without one
// it is the collection's element nature.
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
if len(node.Arguments) == 2 {
v.begin(collection)
predicate := v.visit(node.Arguments[1])
v.end()
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
return predicate.Out(0)
}
// NOTE(review): unlike the other cases, a predicate that fails the
// shape test is not reported here; control leaves the switch and
// continues with the builtin.Index lookup below — confirm intended.
} else {
if isUnknown(collection) {
return unknown
}
return collection.Elem()
}
case "find", "findLast":
// Result is the element nature of the collection (unknown when the
// collection is unknown).
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
v.begin(collection)
predicate := v.visit(node.Arguments[1])
v.end()
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) {
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String())
}
if isUnknown(collection) {
return unknown
}
return collection.Elem()
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")
case "findIndex", "findLastIndex":
// Result is an integer; the predicate must return bool or unknown.
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
v.begin(collection)
predicate := v.visit(node.Arguments[1])
v.end()
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) {
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String())
}
return integerNature
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")
case "groupBy":
// Result is typed as map[any][]any, with ArrayOf set to an array of
// the collection's element nature.
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
v.begin(collection)
predicate := v.visit(node.Arguments[1])
v.end()
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
groups := arrayOf(collection.Elem())
return Nature{Type: reflect.TypeOf(map[any][]any{}), ArrayOf: &groups}
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")
case "sortBy":
// Result has the nature of the collection itself. An optional third
// argument is visited outside the predicate scope and its nature is
// discarded.
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
v.begin(collection)
predicate := v.visit(node.Arguments[1])
v.end()
if len(node.Arguments) == 3 {
_ = v.visit(node.Arguments[2])
}
if isFunc(predicate) &&
predicate.NumOut() == 1 &&
predicate.NumIn() == 1 && isUnknown(predicate.In(0)) {
return collection
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")
case "reduce":
// The predicate scope exposes "index" (integer) and "acc" (unknown).
// An optional third argument is visited after the scope is closed and
// its nature is discarded. The result is whatever the predicate body
// evaluates to.
collection := v.visit(node.Arguments[0]).Deref()
if !isArray(collection) && !isUnknown(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}
v.begin(collection, scopeVar{"index", integerNature}, scopeVar{"acc", unknown})
predicate := v.visit(node.Arguments[1])
v.end()
if len(node.Arguments) == 3 {
_ = v.visit(node.Arguments[2])
}
if isFunc(predicate) && predicate.NumOut() == 1 {
return *predicate.PredicateOut
}
return v.error(node.Arguments[1], "predicate should has two input and one output param")
}
// Not handled above: fall back to the table of regular builtins.
if id, ok := builtin.Index[node.Name]; ok {
switch node.Name {
case "get":
return v.checkBuiltinGet(node)
}
return v.checkFunction(builtin.Builtins[id], node, node.Arguments)
}
return v.error(node, "unknown builtin %v", node.Name)
}
// scopeVar is a named pointer variable passed to begin, which stores it in
// the vars map of the new predicate scope.
type scopeVar struct {
varName string
varNature Nature
}
// begin pushes a new predicate scope for the given collection nature. Every
// entry of vars becomes a named pointer variable of that scope.
func (v *checker) begin(collectionNature Nature, vars ...scopeVar) {
	named := make(map[string]Nature, len(vars))
	for _, sv := range vars {
		named[sv.varName] = sv.varNature
	}
	v.predicateScopes = append(v.predicateScopes, predicateScope{
		collection: collectionNature,
		vars:       named,
	})
}
// end pops the predicate scope pushed by the most recent begin.
func (v *checker) end() {
	last := len(v.predicateScopes) - 1
	v.predicateScopes = v.predicateScopes[:last]
}
// checkBuiltinGet checks a call to the builtin get(base, key). It requires
// exactly two arguments, resolves `get($env, "name")` through the
// environment, and otherwise accepts slices/arrays indexed by an integer and
// maps indexed by a key assignable to the map's key type. The result is the
// element nature of the base, or unknown when the base is unknown.
func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) Nature {
	if got := len(node.Arguments); got != 2 {
		return v.error(node, "invalid number of arguments (expected 2, got %d)", got)
	}

	target, key := node.Arguments[0], node.Arguments[1]
	base := v.visit(target)
	prop := v.visit(key)

	if ident, isIdent := target.(*ast.IdentifierNode); isIdent && ident.Value == "$env" {
		if lit, isString := key.(*ast.StringNode); isString {
			if envNature, found := v.config.Env.Get(lit.Value); found {
				return envNature
			}
		}
		return unknown
	}

	if isUnknown(base) {
		return unknown
	}

	switch base.Kind() {
	case reflect.Slice, reflect.Array:
		if !isInteger(prop) && !isUnknown(prop) {
			return v.error(key, "non-integer slice index %s", prop)
		}
		return base.Elem()
	case reflect.Map:
		if !prop.AssignableTo(base.Key()) && !isUnknown(prop) {
			return v.error(key, "cannot use %s to get an element from %s", prop, base)
		}
		return base.Elem()
	}
	return v.error(target, "type %v does not support indexing", base)
}
// checkFunction checks a call to a builtin.Function and returns the nature
// of its result. Three strategies are used, in this order:
//
//  1. f.Validate, when set, receives the argument types (any for unknown
//     arguments) and decides the result type.
//  2. With no declared overloads (f.Types is empty), the arguments are
//     checked against f.Type().
//  3. Otherwise each overload in f.Types is tried in order and the first one
//     accepting the arguments wins; if none does, the error of the last
//     attempt is recorded.
func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) Nature {
	// record keeps err as the checker's error unless one is already stored.
	record := func(err *file.Error) Nature {
		if v.err == nil {
			v.err = err
		}
		return unknown
	}

	if f.Validate != nil {
		argTypes := make([]reflect.Type, 0, len(arguments))
		for _, arg := range arguments {
			if argNature := v.visit(arg); isUnknown(argNature) {
				argTypes = append(argTypes, anyType)
			} else {
				argTypes = append(argTypes, argNature.Type)
			}
		}
		resultType, err := f.Validate(argTypes)
		if err != nil {
			return v.error(node, "%v", err)
		}
		return Nature{Type: resultType}
	}

	if len(f.Types) == 0 {
		// No overloads were specified, so the function's own type is used.
		out, err := v.checkArguments(f.Name, Nature{Type: f.Type()}, arguments, node)
		if err != nil {
			return record(err)
		}
		return out
	}

	var lastErr *file.Error
	for _, overload := range f.Types {
		out, err := v.checkArguments(f.Name, Nature{Type: overload}, arguments, node)
		if err == nil {
			// The matching overload is stored on the callee so the compiler
			// can correctly handle the OpDeref opcode.
			if call, isCall := node.(*ast.CallNode); isCall {
				call.Callee.SetType(overload)
			}
			return out
		}
		lastErr = err
	}

	if lastErr == nil {
		return v.error(node, "no matching overload for %v", f.Name)
	}
	return record(lastErr)
}
// checkArguments verifies that arguments can be passed to the function
// nature fn, which is called name in error messages, and returns the nature
// of the function's first result. node is the call node, used to locate
// errors about the call as a whole.
//
// The function must return one or two values. Integer arguments passed where
// a float (or a different integer kind) is expected are retyped in place in
// the arguments slice. nil is accepted only for pointer and interface
// parameters. Unknown arguments are always accepted.
func (v *checker) checkArguments(
	name string,
	fn Nature,
	arguments []ast.Node,
	node ast.Node,
) (Nature, *file.Error) {
	if isUnknown(fn) {
		return unknown, nil
	}

	// failAt builds an error located at the given node.
	failAt := func(at ast.Node, format string, args ...any) *file.Error {
		return &file.Error{
			Location: at.Location(),
			Message:  fmt.Sprintf(format, args...),
		}
	}

	switch outs := fn.NumOut(); {
	case outs == 0:
		return unknown, failAt(node, "func %v doesn't return value", name)
	case outs > 2:
		return unknown, failAt(node, "func %v returns more then two values", name)
	}

	// For a method on an env the first input is the receiver: it is not
	// counted as an argument and is skipped when indexing parameters.
	// TODO: Move subtraction to the Nature.NumIn() and Nature.In() methods.
	receiver := 0
	if fn.Method {
		receiver = 1
	}
	wantArgs := fn.NumIn() - receiver
	variadic := fn.IsVariadic()
	gotArgs := len(arguments)

	var countErr *file.Error
	switch {
	case variadic:
		if gotArgs < wantArgs-1 {
			countErr = failAt(node, "not enough arguments to call %v", name)
		}
	case gotArgs > wantArgs:
		countErr = failAt(node, "too many arguments to call %v", name)
	case gotArgs < wantArgs:
		countErr = failAt(node, "not enough arguments to call %v", name)
	}
	if countErr != nil {
		// Arguments are still visited so they get type checked: a patch can
		// fix the arity error later.
		for _, arg := range arguments {
			_ = v.visit(arg)
		}
		return fn.Out(0), countErr
	}

	for i, arg := range arguments {
		argNature := v.visit(arg)

		var in Nature
		if variadic && i >= wantArgs-1 {
			// For fn(xs ...int) Go reports the last parameter as []int; as
			// arguments are compared one by one, its element type is needed.
			in = fn.In(fn.NumIn() - 1).Elem()
		} else {
			in = fn.In(i + receiver)
		}

		switch {
		case isFloat(in) && isInteger(argNature):
			traverseAndReplaceIntegerNodesWithFloatNodes(&arguments[i], in)
			continue
		case isInteger(in) && isInteger(argNature) && argNature.Kind() != in.Kind():
			traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in)
			continue
		}

		if isNil(argNature) {
			if kind := in.Kind(); kind == reflect.Ptr || kind == reflect.Interface {
				continue
			}
			return unknown, failAt(arg, "cannot use nil as argument (type %s) to call %v", in, name)
		}

		// The original type (like *time.Time) is tried first because the
		// parameter itself may be a pointer; the dereferenced type is tried
		// next (func(int) called with *int gets an OpDeref on the argument).
		if argNature.AssignableTo(in) || argNature.Deref().AssignableTo(in) || isUnknown(argNature) {
			continue
		}
		return unknown, failAt(arg, "cannot use %s as argument (type %s) to call %v ", argNature, in, name)
	}

	return fn.Out(0), nil
}
// traverseAndReplaceIntegerNodesWithFloatNodes rewrites, in place, every
// integer literal reachable from *node through unary nodes and through
// "+", "-" and "*" binary nodes into a float literal typed as newNature.
// Other node kinds and other binary operators are left untouched.
func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newNature Nature) {
	switch n := (*node).(type) {
	case *ast.IntegerNode:
		asFloat := &ast.FloatNode{Value: float64(n.Value)}
		*node = asFloat
		(*node).SetType(newNature.Type)
	case *ast.UnaryNode:
		traverseAndReplaceIntegerNodesWithFloatNodes(&n.Node, newNature)
	case *ast.BinaryNode:
		if op := n.Operator; op == "+" || op == "-" || op == "*" {
			traverseAndReplaceIntegerNodesWithFloatNodes(&n.Left, newNature)
			traverseAndReplaceIntegerNodesWithFloatNodes(&n.Right, newNature)
		}
	}
}
// traverseAndReplaceIntegerNodesWithIntegerNodes retypes integer literals
// reachable from *node to newNature's type. Unary nodes are retyped too and
// then descended into; "+", "-" and "*" binary nodes are only descended into.
func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newNature Nature) {
	switch n := (*node).(type) {
	case *ast.IntegerNode:
		n.SetType(newNature.Type)
	case *ast.UnaryNode:
		n.SetType(newNature.Type)
		traverseAndReplaceIntegerNodesWithIntegerNodes(&n.Node, newNature)
	case *ast.BinaryNode:
		// TODO: Binary node return type is dependent on the type of the operands. We can't just change the type of the node.
		if op := n.Operator; op == "+" || op == "-" || op == "*" {
			traverseAndReplaceIntegerNodesWithIntegerNodes(&n.Left, newNature)
			traverseAndReplaceIntegerNodesWithIntegerNodes(&n.Right, newNature)
		}
	}
}
// PredicateNode checks the body of a predicate and describes the predicate
// as a func(any) whose single result is the body's type: any for an unknown
// body, and no result at all for a nil body. The body's nature is also
// exposed through PredicateOut.
func (v *checker) PredicateNode(node *ast.PredicateNode) Nature {
	body := v.visit(node.Node)

	var results []reflect.Type
	switch {
	case isUnknown(body):
		results = append(results, anyType)
	case !isNil(body):
		results = append(results, body.Type)
	}

	params := []reflect.Type{anyType}
	return Nature{
		Type:         reflect.FuncOf(params, results, false),
		PredicateOut: &body,
	}
}
// PointerNode checks a pointer accessor (`#` or `#name`) used inside a
// predicate.
//
// An unnamed pointer resolves to the element nature of the collection of the
// innermost predicate scope (unknown if that collection is unknown). A named
// pointer is looked up among the variables of that scope. Using a pointer
// outside of any predicate, or naming an undeclared one, is an error.
func (v *checker) PointerNode(node *ast.PointerNode) Nature {
	if len(v.predicateScopes) == 0 {
		return v.error(node, "cannot use pointer accessor outside predicate")
	}
	scope := v.predicateScopes[len(v.predicateScopes)-1]

	if node.Name == "" {
		if isUnknown(scope.collection) {
			return unknown
		}
		switch scope.collection.Kind() {
		case reflect.Array, reflect.Slice:
			return scope.collection.Elem()
		}
		// Report the offending collection nature, not the internal scope
		// struct, so the message names the type the user actually supplied.
		return v.error(node, "cannot use %v as array", scope.collection)
	}

	// Reading from a nil map is safe in Go, so no nil guard is needed.
	if t, ok := scope.vars[node.Name]; ok {
		return t
	}
	return v.error(node, "unknown pointer #%v", node.Name)
}
// VariableDeclaratorNode checks a variable declaration followed by the
// expression it scopes over. The name must not collide with an environment
// entry, a function, a builtin or an already declared variable. The value is
// checked first; the variable is then visible only while the expression is
// checked, and the expression's nature is the result.
func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature {
	name := node.Name

	if _, taken := v.config.Env.Get(name); taken {
		return v.error(node, "cannot redeclare %v", name)
	}
	if _, taken := v.config.Functions[name]; taken {
		return v.error(node, "cannot redeclare function %v", name)
	}
	if _, taken := v.config.Builtins[name]; taken {
		return v.error(node, "cannot redeclare builtin %v", name)
	}
	if _, taken := v.lookupVariable(name); taken {
		return v.error(node, "cannot redeclare variable %v", name)
	}

	declared := varScope{name: name, nature: v.visit(node.Value)}

	v.varScopes = append(v.varScopes, declared)
	result := v.visit(node.Expr)
	v.varScopes = v.varScopes[:len(v.varScopes)-1]

	return result
}
// SequenceNode checks every expression of a sequence in order and returns
// the nature of the last one. An empty sequence is an error.
func (v *checker) SequenceNode(node *ast.SequenceNode) Nature {
	count := len(node.Nodes)
	if count == 0 {
		return v.error(node, "empty sequence expression")
	}

	var result Nature
	for i := 0; i < count; i++ {
		result = v.visit(node.Nodes[i])
	}
	return result
}
// lookupVariable searches the declared variables from the most recently
// declared one backwards and returns the first whose name matches.
func (v *checker) lookupVariable(name string) (varScope, bool) {
	for depth := len(v.varScopes); depth > 0; depth-- {
		if candidate := v.varScopes[depth-1]; candidate.name == name {
			return candidate, true
		}
	}
	return varScope{}, false
}
// ConditionalNode checks a conditional expression (`cond ? a : b`). The
// condition must be bool or unknown. If exactly one branch is nil the other
// branch decides the result; two nil branches yield nil; otherwise the first
// branch's nature is used when it is assignable to the second, and unknown
// when it is not.
func (v *checker) ConditionalNode(node *ast.ConditionalNode) Nature {
	cond := v.visit(node.Cond)
	if !isBool(cond) && !isUnknown(cond) {
		return v.error(node.Cond, "non-bool expression (type %v) used as condition", cond)
	}

	whenTrue := v.visit(node.Exp1)
	whenFalse := v.visit(node.Exp2)

	trueIsNil, falseIsNil := isNil(whenTrue), isNil(whenFalse)
	switch {
	case trueIsNil && falseIsNil:
		return nilNature
	case trueIsNil:
		return whenFalse
	case falseIsNil:
		return whenTrue
	case whenTrue.AssignableTo(whenFalse):
		return whenTrue
	}
	return unknown
}
// ArrayNode checks every element of an array literal. When each element has
// the same kind as the one before it, the result is an array of the last
// element's nature; otherwise it is the generic array nature.
func (v *checker) ArrayNode(node *ast.ArrayNode) Nature {
	var previous Nature
	uniform := true

	for i, element := range node.Nodes {
		current := v.visit(element)
		if i > 0 && current.Kind() != previous.Kind() {
			uniform = false
		}
		previous = current
	}

	if !uniform {
		return arrayNature
	}
	return arrayOf(previous)
}
// MapNode checks every pair of a map literal and returns the map nature.
func (v *checker) MapNode(node *ast.MapNode) Nature {
	for i := range node.Pairs {
		v.visit(node.Pairs[i])
	}
	return mapNature
}
// PairNode checks the key and then the value of a map pair. The pair itself
// has the nil nature.
func (v *checker) PairNode(node *ast.PairNode) Nature {
	for _, part := range [2]ast.Node{node.Key, node.Value} {
		v.visit(part)
	}
	return nilNature
}