Files
kubevpn/vendor/github.com/expr-lang/expr/patcher/operator_override.go
2025-04-19 10:06:56 +08:00

148 lines
4.0 KiB
Go

package patcher
import (
"fmt"
"reflect"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/builtin"
"github.com/expr-lang/expr/checker/nature"
"github.com/expr-lang/expr/conf"
)
// OperatorOverloading is an AST patcher that replaces binary expressions using
// Operator with a call to one of the Overloads functions. Candidates are looked
// up first in Functions and then in Env, and the first one whose two operand
// parameters accept the operand types is used (see FindSuitableOperatorOverload).
// Check can be used to validate the configuration up front; it panics on a
// missing overload or one with an unsuitable signature.
type OperatorOverloading struct {
Operator string // Operator token to overload.
Overloads []string // List of function names to replace operator with.
Env *nature.Nature // Env type.
Functions conf.FunctionsTable // Env functions.
applied bool // Flag to indicate if any changes were made to the tree.
}
// Visit rewrites *node in place when it is a binary expression using
// p.Operator and some overload in p.Overloads accepts its operand types.
// The binary node is replaced by a call to that overload, passing the
// original left and right operands as arguments, and the call node is typed
// with the overload's return type. Any successful rewrite sets the applied
// flag, which ShouldRepeat reports.
func (p *OperatorOverloading) Visit(node *ast.Node) {
	bin, isBinary := (*node).(*ast.BinaryNode)
	if !isBinary || bin.Operator != p.Operator {
		return
	}

	retType, fnName, found := p.FindSuitableOperatorOverload(bin.Left.Type(), bin.Right.Type())
	if !found {
		return
	}

	call := &ast.CallNode{
		Callee:    &ast.IdentifierNode{Value: fnName},
		Arguments: []ast.Node{bin.Left, bin.Right},
	}
	call.SetType(retType)
	ast.Patch(node, call)
	p.applied = true
}
// Reset clears the applied flag. It must be called before every walk over the
// AST so that ShouldRepeat reflects only the rewrites made during that walk.
func (p *OperatorOverloading) Reset() {
	p.applied = false
}
// ShouldRepeat reports whether Visit has rewritten at least one node since the
// last call to Reset.
func (p *OperatorOverloading) ShouldRepeat() bool {
	return p.applied
}
// FindSuitableOperatorOverload searches p.Overloads for a function that accepts
// operands of types l and r. Entries registered in Functions are tried first;
// functions and methods exposed by Env are consulted only when none of those
// fits. It returns the overload's result type, its name, and whether a match
// was found.
func (p *OperatorOverloading) FindSuitableOperatorOverload(l, r reflect.Type) (reflect.Type, string, bool) {
	if t, fn, ok := p.findSuitableOperatorOverloadInFunctions(l, r); ok {
		return t, fn, true
	}
	return p.findSuitableOperatorOverloadInTypes(l, r)
}
// findSuitableOperatorOverloadInTypes walks p.Overloads in order and returns
// the first name that Env exposes and whose operand parameters accept (l, r).
// For methods the receiver occupies the first input slot, so the operands are
// matched starting one position later.
func (p *OperatorOverloading) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) {
	for _, name := range p.Overloads {
		candidate, exists := p.Env.Get(name)
		if !exists {
			continue
		}

		offset := 0
		if candidate.Method {
			offset = 1 // skip the receiver
		}

		if ret, fits := checkTypeSuits(candidate.Type, l, r, offset); fits {
			return ret, name, true
		}
	}
	return nil, "", false
}
// findSuitableOperatorOverloadInFunctions walks p.Overloads in order and, for
// each name registered in p.Functions, tries that function's declared
// signatures in order. The first signature whose two parameters accept (l, r)
// wins; its result type and the function name are returned.
func (p *OperatorOverloading) findSuitableOperatorOverloadInFunctions(l, r reflect.Type) (reflect.Type, string, bool) {
	for _, name := range p.Overloads {
		registered, exists := p.Functions[name]
		if !exists {
			continue
		}
		for _, signature := range registered.Types {
			// Registered functions have no receiver: operands start at input 0.
			if ret, fits := checkTypeSuits(signature, l, r, 0); fits {
				return ret, name, true
			}
		}
	}
	return nil, "", false
}
// checkTypeSuits reports whether the function type t can be called with
// operands of types l and r, and if so returns t's first result type.
//
// The operands are matched against t's inputs at positions firstInIndex and
// firstInIndex+1 (callers pass 1 for methods, whose first input is the
// receiver). An operand fits a parameter when its type is identical to the
// parameter type, or when the parameter is an interface type and the operand
// type is either nil or implements that interface.
//
// A t that is nil, is not a function, has too few inputs for the requested
// positions, or returns nothing is reported as unsuitable (nil, false) rather
// than triggering a panic inside reflect.
func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) {
	if t == nil || t.Kind() != reflect.Func || firstInIndex < 0 ||
		t.NumIn() < firstInIndex+2 || t.NumOut() < 1 {
		return nil, false
	}

	fits := func(arg, param reflect.Type) bool {
		if arg == param {
			return true
		}
		return param.Kind() == reflect.Interface && (arg == nil || arg.Implements(param))
	}

	if !fits(l, t.In(firstInIndex)) || !fits(r, t.In(firstInIndex+1)) {
		return nil, false
	}
	return t.Out(0), true
}
// Check validates that every name in p.Overloads can serve as an overload for
// p.Operator, and panics with a descriptive error otherwise. A name must be
// registered in Functions or be a function/method in Env. Each source that
// provides the name is then validated:
//   - Env entries must take two operands (plus the receiver for methods) and
//     return one value (checkType).
//   - Functions entries must declare at least one signature, each taking two
//     inputs and returning one value (checkFunc).
//
// An Env entry that exists but has no reflect.Type, or a non-function type, is
// reported with the usual "does not have a correct signature" error. It is no
// longer allowed to reach a nil dereference or a reflect panic from NumIn.
func (p *OperatorOverloading) Check() {
	for _, fn := range p.Overloads {
		fnType, foundType := p.Env.Get(fn)
		fnFunc, foundFunc := p.Functions[fn]

		// Nil-safe: an Env entry may be present without a usable reflect.Type.
		isFuncType := foundType && fnType.Type != nil && fnType.Type.Kind() == reflect.Func

		if !foundFunc && !isFuncType {
			panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator))
		}

		if foundType {
			if !isFuncType {
				// Otherwise checkType would call NumIn on a nil or non-function
				// type and panic inside reflect with an unhelpful message.
				panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, p.Operator))
			}
			checkType(fnType, fn, p.Operator)
		}

		if foundFunc {
			checkFunc(fnFunc, fn, p.Operator)
		}
	}
}
// checkType panics unless the Env entry fnType takes exactly two operands and
// returns exactly one value. When fnType is a method, its receiver counts as
// one additional input.
func checkType(fnType nature.Nature, fn string, operator string) {
	wantIn := 2
	if fnType.Method {
		wantIn++ // the receiver is the first input of a method
	}

	sig := fnType.Type
	if sig.NumIn() == wantIn && sig.NumOut() == 1 {
		return
	}
	panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
}
// checkFunc panics unless the registered function fn declares at least one
// signature and every declared signature takes exactly two inputs and returns
// exactly one value.
func checkFunc(fn *builtin.Function, name string, operator string) {
	if len(fn.Types) == 0 {
		panic(fmt.Errorf("function %q for %q operator misses types", name, operator))
	}

	for i := range fn.Types {
		sig := fn.Types[i]
		if sig.NumIn() == 2 && sig.NumOut() == 1 {
			continue
		}
		panic(fmt.Errorf("function %q for %q operator does not have a correct signature", name, operator))
	}
}