mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-12 18:51:02 +08:00
148 lines
4.0 KiB
Go
148 lines
4.0 KiB
Go
package patcher
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"github.com/expr-lang/expr/ast"
|
|
"github.com/expr-lang/expr/builtin"
|
|
"github.com/expr-lang/expr/checker/nature"
|
|
"github.com/expr-lang/expr/conf"
|
|
)
|
|
|
|
type OperatorOverloading struct {
|
|
Operator string // Operator token to overload.
|
|
Overloads []string // List of function names to replace operator with.
|
|
Env *nature.Nature // Env type.
|
|
Functions conf.FunctionsTable // Env functions.
|
|
applied bool // Flag to indicate if any changes were made to the tree.
|
|
}
|
|
|
|
func (p *OperatorOverloading) Visit(node *ast.Node) {
|
|
binaryNode, ok := (*node).(*ast.BinaryNode)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if binaryNode.Operator != p.Operator {
|
|
return
|
|
}
|
|
|
|
leftType := binaryNode.Left.Type()
|
|
rightType := binaryNode.Right.Type()
|
|
|
|
ret, fn, ok := p.FindSuitableOperatorOverload(leftType, rightType)
|
|
if ok {
|
|
newNode := &ast.CallNode{
|
|
Callee: &ast.IdentifierNode{Value: fn},
|
|
Arguments: []ast.Node{binaryNode.Left, binaryNode.Right},
|
|
}
|
|
newNode.SetType(ret)
|
|
ast.Patch(node, newNode)
|
|
p.applied = true
|
|
}
|
|
}
|
|
|
|
// Tracking must be reset before every walk over the AST tree
|
|
func (p *OperatorOverloading) Reset() {
|
|
p.applied = false
|
|
}
|
|
|
|
func (p *OperatorOverloading) ShouldRepeat() bool {
|
|
return p.applied
|
|
}
|
|
|
|
func (p *OperatorOverloading) FindSuitableOperatorOverload(l, r reflect.Type) (reflect.Type, string, bool) {
|
|
t, fn, ok := p.findSuitableOperatorOverloadInFunctions(l, r)
|
|
if !ok {
|
|
t, fn, ok = p.findSuitableOperatorOverloadInTypes(l, r)
|
|
}
|
|
return t, fn, ok
|
|
}
|
|
|
|
func (p *OperatorOverloading) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) {
|
|
for _, fn := range p.Overloads {
|
|
fnType, ok := p.Env.Get(fn)
|
|
if !ok {
|
|
continue
|
|
}
|
|
firstInIndex := 0
|
|
if fnType.Method {
|
|
firstInIndex = 1 // As first argument to method is receiver.
|
|
}
|
|
ret, done := checkTypeSuits(fnType.Type, l, r, firstInIndex)
|
|
if done {
|
|
return ret, fn, true
|
|
}
|
|
}
|
|
return nil, "", false
|
|
}
|
|
|
|
func (p *OperatorOverloading) findSuitableOperatorOverloadInFunctions(l, r reflect.Type) (reflect.Type, string, bool) {
|
|
for _, fn := range p.Overloads {
|
|
fnType, ok := p.Functions[fn]
|
|
if !ok {
|
|
continue
|
|
}
|
|
firstInIndex := 0
|
|
for _, overload := range fnType.Types {
|
|
ret, done := checkTypeSuits(overload, l, r, firstInIndex)
|
|
if done {
|
|
return ret, fn, true
|
|
}
|
|
}
|
|
}
|
|
return nil, "", false
|
|
}
|
|
|
|
// checkTypeSuits reports whether the function type t can be called with
// operands of types l and r, where the operand parameters start at input
// index firstInIndex (1 for methods, whose first input is the receiver).
//
// An operand fits a parameter when the types are identical, or when the
// parameter is an interface and the operand type is nil or implements it.
//
// On success it returns t's first result type and true. It returns
// (nil, false) when the operands do not fit, and also when t cannot be an
// operator overload at all: nil, not a function, fewer than two parameters
// after firstInIndex, or no result. The latter cases would otherwise panic
// inside reflect; they are reachable because FindSuitableOperatorOverload is
// exported and can be called without Check having validated the signatures.
func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) {
	if t == nil || t.Kind() != reflect.Func {
		return nil, false
	}
	if firstInIndex < 0 || t.NumIn() < firstInIndex+2 || t.NumOut() < 1 {
		return nil, false
	}

	fits := func(arg, param reflect.Type) bool {
		if arg == param {
			return true
		}
		if param.Kind() != reflect.Interface {
			return false
		}
		// A nil operand type is accepted by any interface parameter.
		return arg == nil || arg.Implements(param)
	}

	if !fits(l, t.In(firstInIndex)) || !fits(r, t.In(firstInIndex+1)) {
		return nil, false
	}
	return t.Out(0), true
}
|
|
|
|
func (p *OperatorOverloading) Check() {
|
|
for _, fn := range p.Overloads {
|
|
fnType, foundType := p.Env.Get(fn)
|
|
fnFunc, foundFunc := p.Functions[fn]
|
|
if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) {
|
|
panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator))
|
|
}
|
|
|
|
if foundType {
|
|
checkType(fnType, fn, p.Operator)
|
|
}
|
|
|
|
if foundFunc {
|
|
checkFunc(fnFunc, fn, p.Operator)
|
|
}
|
|
}
|
|
}
|
|
|
|
func checkType(fnType nature.Nature, fn string, operator string) {
|
|
requiredNumIn := 2
|
|
if fnType.Method {
|
|
requiredNumIn = 3 // As first argument of method is receiver.
|
|
}
|
|
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
|
|
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
|
|
}
|
|
}
|
|
|
|
func checkFunc(fn *builtin.Function, name string, operator string) {
|
|
if len(fn.Types) == 0 {
|
|
panic(fmt.Errorf("function %q for %q operator misses types", name, operator))
|
|
}
|
|
for _, t := range fn.Types {
|
|
if t.NumIn() != 2 || t.NumOut() != 1 {
|
|
panic(fmt.Errorf("function %q for %q operator does not have a correct signature", name, operator))
|
|
}
|
|
}
|
|
}
|