add control flow obfuscation

Implemented control flow flattening with additional features such as block splitting and junk jumps
This commit is contained in:
pagran
2023-06-25 22:57:27 +02:00
committed by GitHub
parent d89a55687c
commit 0e2e483472
14 changed files with 2850 additions and 34 deletions

View File

@@ -6,6 +6,7 @@ package asthelper
import (
	"fmt"
	"go/ast"
	"go/constant"
	"go/token"
	"math"
	"strconv"
)
@@ -85,3 +86,56 @@ func DataToByteSlice(data []byte) *ast.CallExpr {
Args: []ast.Expr{StringLit(string(data))},
}
}
// SelectExpr builds the selector expression "x.sel".
func SelectExpr(x ast.Expr, sel *ast.Ident) *ast.SelectorExpr {
	selector := new(ast.SelectorExpr)
	selector.X = x
	selector.Sel = sel
	return selector
}
// AssignDefineStmt builds the short variable declaration "Lhs := Rhs"
// with exactly one expression on each side.
func AssignDefineStmt(Lhs ast.Expr, Rhs ast.Expr) *ast.AssignStmt {
	stmt := &ast.AssignStmt{Tok: token.DEFINE}
	stmt.Lhs = []ast.Expr{Lhs}
	stmt.Rhs = []ast.Expr{Rhs}
	return stmt
}
// CallExprByName builds the call "fun(args...)" where the callee is
// referenced by a plain identifier named fun.
func CallExprByName(fun string, args ...ast.Expr) *ast.CallExpr {
	callee := ast.NewIdent(fun)
	return CallExpr(callee, args...)
}
// AssignStmt builds the plain assignment "Lhs = Rhs" with exactly one
// expression on each side.
func AssignStmt(Lhs ast.Expr, Rhs ast.Expr) *ast.AssignStmt {
	stmt := &ast.AssignStmt{Tok: token.ASSIGN}
	stmt.Lhs = []ast.Expr{Lhs}
	stmt.Rhs = []ast.Expr{Rhs}
	return stmt
}
// IndexExprByExpr builds the index expression "xExpr[indexExpr]".
func IndexExprByExpr(xExpr, indexExpr ast.Expr) *ast.IndexExpr {
	indexed := new(ast.IndexExpr)
	indexed.X = xExpr
	indexed.Index = indexExpr
	return indexed
}
// ConstToAst converts a constant value into the AST expression that spells
// it in Go source.
//
// Booleans become the identifiers true/false, strings and integers become
// exact basic literals, floats become a float literal and complex values
// become a call to the builtin complex(re, im) built from the converted real
// and imaginary parts. It panics for any other kind of constant.
func ConstToAst(val constant.Value) ast.Expr {
	switch val.Kind() {
	case constant.Bool:
		return ast.NewIdent(val.ExactString())
	case constant.String:
		return &ast.BasicLit{Kind: token.STRING, Value: val.ExactString()}
	case constant.Int:
		return &ast.BasicLit{Kind: token.INT, Value: val.ExactString()}
	case constant.Float:
		// val.String() is a short human-readable form that may be an
		// approximation, so using it would silently change the value of
		// the constant. Prefer the shortest text that round-trips the
		// float64 value; keep the approximate form only when the value
		// overflows or underflows float64 and cannot be spelled that way.
		lit := val.String()
		if f, _ := constant.Float64Val(val); !math.IsInf(f, 0) && (f != 0 || constant.Sign(val) == 0) {
			lit = strconv.FormatFloat(f, 'g', -1, 64)
		}
		return &ast.BasicLit{Kind: token.FLOAT, Value: lit}
	case constant.Complex:
		return CallExprByName("complex", ConstToAst(constant.Real(val)), ConstToAst(constant.Imag(val)))
	default:
		panic("unreachable")
	}
}

View File

@@ -0,0 +1,186 @@
package ctrlflow
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"log"
mathrand "math/rand"
"strconv"
"strings"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ssa"
ah "mvdan.cc/garble/internal/asthelper"
"mvdan.cc/garble/internal/ssa2ast"
)
const (
	// mergedFileName is the name registered in the file set for the
	// generated file and returned by Obfuscate as the new file name.
	mergedFileName = "GARBLE_controlflow.go"
	// directiveName is the comment prefix parseDirective looks for.
	directiveName = "//garble:controlflow"
	// importPrefix is combined with a running counter to name the imports
	// added to the generated file.
	importPrefix = "___garble_import"

	// Defaults used by Obfuscate when the directive does not set the
	// block_splits, junk_jumps or flatten_passes parameter respectively.
	defaultBlockSplits = 0
	defaultJunkJumps = 0
	defaultFlattenPasses = 1
)
// directiveParamMap maps directive parameter names to their raw string values.
type directiveParamMap map[string]string

// GetInt returns the parameter called name parsed as a decimal integer.
// When the parameter is absent, def is returned instead. A parameter that is
// present but is not a valid integer makes GetInt panic.
func (m directiveParamMap) GetInt(name string, def int) int {
	if rawVal, present := m[name]; present {
		parsed, err := strconv.Atoi(rawVal)
		if err != nil {
			panic(fmt.Errorf("invalid flag %s format: %v", name, err))
		}
		return parsed
	}
	return def
}
// parseDirective parses a directive comment and returns its parameters.
//
// The boolean result reports whether directive is a control flow directive:
// it must start with directiveName, and the name must either end the comment
// or be followed by a space or tab, so look-alikes that merely share the
// prefix are rejected. The remaining text is split on whitespace; each field
// should be in the form "key=value" or "key", the latter being stored with an
// empty value. A directive without parameters yields a nil map.
func parseDirective(directive string) (directiveParamMap, bool) {
	fieldsStr, ok := strings.CutPrefix(directive, directiveName)
	if !ok {
		return nil, false
	}
	// The directive name must be a whole word: without this check a
	// different comment sharing the prefix would be picked up as ours.
	if fieldsStr != "" && fieldsStr[0] != ' ' && fieldsStr[0] != '\t' {
		return nil, false
	}

	fields := strings.Fields(fieldsStr)
	if len(fields) == 0 {
		return nil, true
	}
	m := make(map[string]string, len(fields))
	for _, v := range fields {
		// Cut leaves value empty when there is no "=", which is exactly
		// the representation wanted for a bare "key".
		key, value, _ := strings.Cut(v, "=")
		m[key] = value
	}
	return m, true
}
// Obfuscate obfuscates control flow of all functions with directive using control flattening.
// All obfuscated functions are removed from the original file and moved to the new one.
// Obfuscation can be customized by passing parameters from the directive, example:
//
// //garble:controlflow flatten_passes=1 junk_jumps=0 block_splits=0
// func someMethod() {}
//
// flatten_passes - controls the number of passes of control flow flattening. It has exponential complexity and more than 3 passes are not recommended in most cases.
// junk_jumps - controls how many junk jumps are added. It does not affect the final binary by itself, but together with flattening it linearly increases complexity.
// block_splits - controls the number of times the largest block must be split. Together with flattening it improves obfuscation of long blocks without branches.
//
// When no function carries the directive, all results are left at their zero
// values. Otherwise it returns the name and AST of the generated file plus the
// input files that had at least one function moved out of them. An invalid
// integer parameter makes directiveParamMap.GetInt panic rather than being
// reported through err.
func Obfuscate(fset *token.FileSet, ssaPkg *ssa.Package, files []*ast.File, obfRand *mathrand.Rand) (newFileName string, newFile *ast.File, affectedFiles []*ast.File, err error) {
	// ssaFuncs and ssaParams are parallel slices: ssaParams[i] holds the
	// directive parameters of ssaFuncs[i].
	var ssaFuncs []*ssa.Function
	var ssaParams []directiveParamMap

	for _, file := range files {
		affected := false
		for _, decl := range file.Decls {
			funcDecl, ok := decl.(*ast.FuncDecl)
			if !ok || funcDecl.Doc == nil {
				continue
			}

			for _, comment := range funcDecl.Doc.List {
				params, hasDirective := parseDirective(comment.Text)
				if !hasDirective {
					continue
				}

				// Map the declaration back to its SSA function through its position.
				path, _ := astutil.PathEnclosingInterval(file, funcDecl.Pos(), funcDecl.Pos())
				ssaFunc := ssa.EnclosingFunction(ssaPkg, path)
				if ssaFunc == nil {
					panic("function exists in ast but not found in ssa")
				}

				ssaFuncs = append(ssaFuncs, ssaFunc)
				ssaParams = append(ssaParams, params)

				log.Printf("detected function for controlflow %s (params: %v)", funcDecl.Name.Name, params)

				// Remove inplace function from original file
				// TODO: implement a complete function removal
				// For now the declaration is turned into an empty "func _() {}".
				funcDecl.Name = ast.NewIdent("_")
				funcDecl.Body = ah.BlockStmt()
				funcDecl.Recv = nil
				funcDecl.Type = &ast.FuncType{Params: &ast.FieldList{}}
				affected = true
				// Only the first directive comment of a function is honoured.
				break
			}
		}

		if affected {
			affectedFiles = append(affectedFiles, file)
		}
	}

	if len(ssaFuncs) == 0 {
		return
	}

	newFile = &ast.File{
		Package: token.Pos(fset.Base()),
		Name: ast.NewIdent(files[0].Name.Name),
	}
	fset.AddFile(mergedFileName, int(newFile.Package), 1) // required for correct printer output

	funcConfig := ssa2ast.DefaultConfig()
	imports := make(map[string]string) // TODO: indirect imports turned into direct ones currently break the build process
	// Every foreign package gets one uniquely named import in the new file;
	// the current package and a nil package need no qualifier.
	funcConfig.ImportNameResolver = func(pkg *types.Package) *ast.Ident {
		if pkg == nil || pkg.Path() == ssaPkg.Pkg.Path() {
			return nil
		}

		name, ok := imports[pkg.Path()]
		if !ok {
			name = importPrefix + strconv.Itoa(len(imports))
			imports[pkg.Path()] = name
			astutil.AddNamedImport(fset, newFile, name, pkg.Path())
		}
		return ast.NewIdent(name)
	}

	for idx, ssaFunc := range ssaFuncs {
		params := ssaParams[idx]

		split := params.GetInt("block_splits", defaultBlockSplits)
		junkCount := params.GetInt("junk_jumps", defaultJunkJumps)
		passes := params.GetInt("flatten_passes", defaultFlattenPasses)

		// The passes run in a fixed order: splitting, junk jumps, flattening,
		// and finally block index repair.
		applyObfuscation := func(ssaFunc *ssa.Function) {
			for i := 0; i < split; i++ {
				if !applySplitting(ssaFunc, obfRand) {
					break // no more candidates for splitting
				}
			}
			if junkCount > 0 {
				addJunkBlocks(ssaFunc, junkCount, obfRand)
			}
			for i := 0; i < passes; i++ {
				applyFlattening(ssaFunc, obfRand)
			}
			fixBlockIndexes(ssaFunc)
		}

		// Anonymous functions are obfuscated with the parameters of the
		// function that declares them.
		applyObfuscation(ssaFunc)
		for _, anonFunc := range ssaFunc.AnonFuncs {
			applyObfuscation(anonFunc)
		}

		// This err shadows the named result, hence the explicit return values.
		astFunc, err := ssa2ast.Convert(ssaFunc, funcConfig)
		if err != nil {
			return "", nil, nil, err
		}
		newFile.Decls = append(newFile.Decls, astFunc)
	}

	newFileName = mergedFileName
	return
}

43
internal/ctrlflow/ssa.go Normal file
View File

@@ -0,0 +1,43 @@
package ctrlflow
import (
"go/constant"
"go/types"
"reflect"
"unsafe"
"golang.org/x/tools/go/ssa"
)
// setUnexportedField overwrites the field called name of the struct reached
// through objRaw with valRaw, even when that field is unexported. It is used
// to modify unexported fields of ssa api structures.
// Pointers and interfaces wrapping the struct are followed first. It panics
// when the struct has no field with that name.
// TODO: find an alternative way to access private fields or raise a feature request upstream
func setUnexportedField(objRaw interface{}, name string, valRaw interface{}) {
	target := reflect.ValueOf(objRaw)
	for {
		kind := target.Kind()
		if kind != reflect.Pointer && kind != reflect.Interface {
			break
		}
		target = target.Elem()
	}

	field := target.FieldByName(name)
	if !field.IsValid() {
		panic("invalid field: " + name)
	}

	// Rebuild a settable value at the field's address; the value obtained
	// from FieldByName itself is not settable for unexported fields.
	writable := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
	writable.Set(reflect.ValueOf(valRaw))
}
// setBlockParent stores ssaFunc in the unexported "parent" field of block.
func setBlockParent(block *ssa.BasicBlock, ssaFunc *ssa.Function) {
	const parentField = "parent"
	setUnexportedField(block, parentField, ssaFunc)
}
// setBlock stores block in the unexported "block" field of instr.
func setBlock(instr ssa.Instruction, block *ssa.BasicBlock) {
	const blockField = "block"
	setUnexportedField(instr, blockField, block)
}
// setType stores typ in the unexported "typ" field of instr.
func setType(instr ssa.Instruction, typ types.Type) {
	const typeField = "typ"
	setUnexportedField(instr, typeField, typ)
}
// makeSsaInt wraps i into an ssa constant of type int.
func makeSsaInt(i int) *ssa.Const {
	intType := types.Typ[types.Int]
	value := constant.MakeInt64(int64(i))
	return ssa.NewConst(value, intType)
}

View File

@@ -0,0 +1,212 @@
package ctrlflow
import (
"go/token"
"go/types"
mathrand "math/rand"
"strconv"
"golang.org/x/tools/go/ssa"
)
// blockMapping pairs a generated jump block with the block that jump must
// eventually reach. The field order matters: values are built positionally.
type blockMapping struct {
	// Fake is the generated block that jumps into the dispatcher.
	Fake *ssa.BasicBlock
	// Target is the block the replaced edge originally pointed to.
	Target *ssa.BasicBlock
}
// applyFlattening adds a dispatcher block and uses ssa.Phi to redirect all ssa.Jump and ssa.If to the dispatcher,
// additionally shuffle all blocks
//
// Functions with fewer than 3 blocks are left untouched. Every edge of a
// Jump or If is redirected to a fresh jump block that enters the dispatcher;
// the dispatcher's phi selects a per-edge integer and a chain of "phi == N"
// blocks routes control to the original target. The last link of the chain
// falls through to the original entry block.
func applyFlattening(ssaFunc *ssa.Function, obfRand *mathrand.Rand) {
	if len(ssaFunc.Blocks) < 3 {
		return
	}

	// The dispatcher entry starts with the phi that carries the edge number.
	phiInstr := &ssa.Phi{Comment: "ctrflow.phi"}
	setType(phiInstr, types.Typ[types.Int])

	entryBlock := &ssa.BasicBlock{
		Comment: "ctrflow.entry",
		Instrs: []ssa.Instruction{phiInstr},
	}
	setBlockParent(entryBlock, ssaFunc)

	// makeJumpBlock creates a block holding a single jump into the dispatcher
	// and records from as its only predecessor.
	makeJumpBlock := func(from *ssa.BasicBlock) *ssa.BasicBlock {
		jumpBlock := &ssa.BasicBlock{
			Comment: "ctrflow.jump",
			Instrs: []ssa.Instruction{&ssa.Jump{}},
			Preds: []*ssa.BasicBlock{from},
			Succs: []*ssa.BasicBlock{entryBlock},
		}
		setBlockParent(jumpBlock, ssaFunc)
		return jumpBlock
	}

	// map for track fake block -> real block jump
	var blocksMapping []blockMapping
	for _, block := range ssaFunc.Blocks {
		// The last instruction of a block decides how it transfers control.
		existInstr := block.Instrs[len(block.Instrs)-1]
		switch existInstr.(type) {
		case *ssa.Jump:
			targetBlock := block.Succs[0]
			fakeBlock := makeJumpBlock(block)
			blocksMapping = append(blocksMapping, blockMapping{fakeBlock, targetBlock})
			block.Succs[0] = fakeBlock
		case *ssa.If:
			tblock, fblock := block.Succs[0], block.Succs[1]
			// NOTE(review): here the jump blocks record the branch targets as
			// their predecessor, while the Jump case records the jumping
			// block itself - confirm whether this difference is intended.
			fakeTblock, fakeFblock := makeJumpBlock(tblock), makeJumpBlock(fblock)

			blocksMapping = append(blocksMapping, blockMapping{fakeTblock, tblock})
			blocksMapping = append(blocksMapping, blockMapping{fakeFblock, fblock})

			block.Succs[0] = fakeTblock
			block.Succs[1] = fakeFblock
		case *ssa.Return, *ssa.Panic:
			// control flow flattening is not applicable
		default:
			panic("unreachable")
		}
	}

	// Give every redirected edge a random, unique, non-zero number.
	phiIdxs := obfRand.Perm(len(blocksMapping))
	for i := range phiIdxs {
		phiIdxs[i]++ // 0 reserved for real entry block
	}

	var entriesBlocks []*ssa.BasicBlock

	// Note: obfuscatedBlocks starts out sharing its backing array with ssaFunc.Blocks.
	obfuscatedBlocks := ssaFunc.Blocks
	for i, m := range blocksMapping {
		// Predecessor i of the dispatcher contributes phi edge i.
		entryBlock.Preds = append(entryBlock.Preds, m.Fake)
		phiInstr.Edges = append(phiInstr.Edges, makeSsaInt(phiIdxs[i]))

		obfuscatedBlocks = append(obfuscatedBlocks, m.Fake)

		// Build the block "if phi == N goto target else <next check>".
		cond := &ssa.BinOp{X: phiInstr, Op: token.EQL, Y: makeSsaInt(phiIdxs[i])}
		setType(cond, types.Typ[types.Bool])

		*phiInstr.Referrers() = append(*phiInstr.Referrers(), cond)

		ifInstr := &ssa.If{Cond: cond}
		*cond.Referrers() = append(*cond.Referrers(), ifInstr)

		ifBlock := &ssa.BasicBlock{
			Instrs: []ssa.Instruction{cond, ifInstr},
			Succs: []*ssa.BasicBlock{m.Target, nil}, // false branch fulfilled in next iteration or linked to real entry block
		}

		setBlockParent(ifBlock, ssaFunc)

		setBlock(cond, ifBlock)
		setBlock(ifInstr, ifBlock)
		entriesBlocks = append(entriesBlocks, ifBlock)

		if i == 0 {
			// The dispatcher entry jumps into the first check of the chain.
			entryBlock.Instrs = append(entryBlock.Instrs, &ssa.Jump{})
			entryBlock.Succs = []*ssa.BasicBlock{ifBlock}
			ifBlock.Preds = append(ifBlock.Preds, entryBlock)
		} else {
			// link previous block to current
			entriesBlocks[i-1].Succs[1] = ifBlock
			ifBlock.Preds = append(ifBlock.Preds, entriesBlocks[i-1])
		}
	}

	// When no edge number matches, control reaches the original entry block.
	lastFakeEntry := entriesBlocks[len(entriesBlocks)-1]

	realEntryBlock := ssaFunc.Blocks[0]
	lastFakeEntry.Succs[1] = realEntryBlock
	realEntryBlock.Preds = append(realEntryBlock.Preds, lastFakeEntry)

	obfuscatedBlocks = append(obfuscatedBlocks, entriesBlocks...)
	obfRand.Shuffle(len(obfuscatedBlocks), func(i, j int) {
		obfuscatedBlocks[i], obfuscatedBlocks[j] = obfuscatedBlocks[j], obfuscatedBlocks[i]
	})
	// The dispatcher entry is placed first; every other block is in shuffled order.
	ssaFunc.Blocks = append([]*ssa.BasicBlock{entryBlock}, obfuscatedBlocks...)
}
// addJunkBlocks inserts count blocks that do nothing but jump, each placed on
// a randomly chosen outgoing edge of a randomly chosen block. Junk blocks
// become candidates themselves, so chains of junk jumps can appear.
func addJunkBlocks(ssaFunc *ssa.Function, count int, obfRand *mathrand.Rand) {
	if count == 0 {
		return
	}

	// Only blocks with at least one outgoing edge can host a junk jump.
	var pool []*ssa.BasicBlock
	for _, b := range ssaFunc.Blocks {
		if len(b.Succs) == 0 {
			continue
		}
		pool = append(pool, b)
	}
	if len(pool) == 0 {
		return
	}

	for n := 0; n < count; n++ {
		host := pool[obfRand.Intn(len(pool))]
		edge := obfRand.Intn(len(host.Succs))

		junk := &ssa.BasicBlock{
			Comment: "ctrflow.fake." + strconv.Itoa(n),
			Instrs:  []ssa.Instruction{&ssa.Jump{}},
			Preds:   []*ssa.BasicBlock{host},
			Succs:   []*ssa.BasicBlock{host.Succs[edge]},
		}
		setBlockParent(junk, ssaFunc)

		// Splice the junk block into the chosen edge.
		host.Succs[edge] = junk
		ssaFunc.Blocks = append(ssaFunc.Blocks, junk)
		pool = append(pool, junk)
	}
}
// applySplitting cuts the block with the most instructions into two parts at
// a random position: the original block keeps the leading instructions
// followed by a jump, and a new block receives the rest.
// It reports false when no block is large enough to be split.
func applySplitting(ssaFunc *ssa.Function, obfRand *mathrand.Rand) bool {
	const minInstrCount = 1 + 1 // 1 exit instruction + 1 any instruction

	// The first of the largest blocks wins.
	var biggest *ssa.BasicBlock
	for _, candidate := range ssaFunc.Blocks {
		if biggest != nil && len(candidate.Instrs) <= len(biggest.Instrs) {
			continue
		}
		biggest = candidate
	}
	if biggest == nil || len(biggest.Instrs) <= minInstrCount {
		return false
	}

	instrs := biggest.Instrs
	splitIdx := 1 + obfRand.Intn(len(instrs)-2)

	// The head gets its own backing array so that appending the jump cannot
	// overwrite the first instruction of the tail.
	head := make([]ssa.Instruction, 0, splitIdx+1)
	head = append(head, instrs[:splitIdx]...)
	head = append(head, &ssa.Jump{})
	biggest.Instrs = head

	tail := &ssa.BasicBlock{
		Comment: "ctrflow.split." + strconv.Itoa(biggest.Index),
		Instrs:  instrs[splitIdx:],
		Preds:   []*ssa.BasicBlock{biggest},
		Succs:   biggest.Succs,
	}
	setBlockParent(tail, ssaFunc)
	for _, moved := range tail.Instrs {
		setBlock(moved, tail)
	}

	// Fix preds for ssa.Phi working: successors are now entered from the tail.
	for _, succ := range tail.Succs {
		for i := range succ.Preds {
			if succ.Preds[i] == biggest {
				succ.Preds[i] = tail
			}
		}
	}

	ssaFunc.Blocks = append(ssaFunc.Blocks, tail)
	biggest.Succs = []*ssa.BasicBlock{tail}
	return true
}
// fixBlockIndexes makes the Index of every block equal to its position in
// ssaFunc.Blocks.
func fixBlockIndexes(ssaFunc *ssa.Function) {
	blocks := ssaFunc.Blocks
	for idx := 0; idx < len(blocks); idx++ {
		blocks[idx].Index = idx
	}
}

1140
internal/ssa2ast/func.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,398 @@
package ssa2ast
import (
"go/ast"
"go/importer"
"go/printer"
"go/types"
"os"
"os/exec"
"path/filepath"
"testing"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ssa"
"github.com/google/go-cmp/cmp"
"golang.org/x/tools/go/ssa/ssautil"
)
// sigSrc is the source parsed by TestConvertSignature. It declares plain and
// generic functions and methods whose signatures are converted and compared.
const sigSrc = `package main
import "unsafe"
type genericStruct[T interface{}] struct{}
type plainStruct struct {
Dummy struct{}
}
func (s *plainStruct) plainStructFunc() {
}
func (*plainStruct) plainStructAnonFunc() {
}
func (s *genericStruct[T]) genericStructFunc() {
}
func (s *genericStruct[T]) genericStructAnonFunc() (test T) {
return
}
func plainFuncSignature(a int, b string, c struct{}, d struct{ string }, e interface{ Dummy() string }, pointer unsafe.Pointer) (i int, er error) {
return
}
func genericFuncSignature[T interface{ interface{} | ~int64 | bool }, X interface{ comparable }](a T, b X, c genericStruct[struct{ a T }], d genericStruct[T]) (res T) {
return
}
`
// TestConvertSignature converts the type-checked signature of selected
// functions from sigSrc back into a declaration and requires it to match the
// parsed declaration once its body is removed.
func TestConvertSignature(t *testing.T) {
	converter := newFuncConverter(DefaultConfig())
	file, _, typesInfo, _ := mustParseAndTypeCheckFile(sigSrc)

	targets := []string{"plainStructFunc", "plainStructAnonFunc", "genericStructFunc", "plainFuncSignature", "genericFuncSignature"}
	for i := 0; i < len(targets); i++ {
		want := findFunc(file, targets[i])
		want.Body = nil // only the signature is converted

		obj := typesInfo.Defs[want.Name].(*types.Func)
		sig := obj.Type().(*types.Signature)
		got, err := converter.convertSignatureToFuncDecl(obj.Name(), sig)
		if err != nil {
			t.Fatal(err)
		}

		diff := cmp.Diff(want, got, astCmpOpt)
		if diff != "" {
			t.Fatalf("method decl not equals: %s", diff)
		}
	}
}
// mainSrc is the program used by TestConvert: it is run once as written and
// once after every function has been converted from ssa back to ast, and
// both runs must print the same output.
const mainSrc = `package main
import (
"encoding/binary"
"fmt"
"io"
"sort"
"strconv"
"sync"
"time"
"unsafe"
)
func main() {
methodOps()
slicesOps()
iterAndMapsOps()
chanOps()
flowOps()
typeOps()
}
func makeSprintf(tag string) func(vals ...interface{}) {
i := 0
return func(vals ...interface{}) {
fmt.Printf("%s(%d): %v\n", tag, i, vals)
i++
}
}
func return42() int {
return 42
}
type arrayOfInts []int
type structOfArraysOfInts struct {
a arrayOfInts
b arrayOfInts
}
func slicesOps() {
sprintf := makeSprintf("slicesOps")
slice := [...]int{1, 2}
sprintf(slice[0:1:2])
// *ssa.IndexAddr
sprintf(slice)
slice[0] += 1
sprintf(slice)
sprintf(slice[:1])
sprintf(slice[slice[0]:])
sprintf(slice[0:2])
sprintf((*[2]int)(slice[:])[return42()%2]) // *ssa.SliceToArrayPointer
sprintf("test"[return42()%3]) // *ssa.Index
structOfArrays := structOfArraysOfInts{a: slice[1:], b: slice[:1]}
sprintf(structOfArrays.a[:1])
sprintf(structOfArrays.b[:1])
slice2 := make([]string, return42(), return42()*2)
slice2[return42()-1] = "test"
sprintf(slice2)
return
}
func iterAndMapsOps() {
sprintf := makeSprintf("iterAndMapsOps")
// *ssa.MakeMap + *ssa.MapUpdate
mmap := map[string]time.Month{
"April": time.April,
"December": time.December,
"January": time.January,
}
var vals []string
for k := range mmap {
vals = append(vals, k)
}
for _, v := range mmap {
vals = append(vals, v.String())
}
sort.Strings(vals) // Required. Order of map iteration not guaranteed
sprintf(vals)
if v, ok := mmap["?"]; ok {
panic("unreachable: " + v.String())
}
for idx, s := range "hello world" {
sprintf(idx, s)
}
sprintf(mmap["April"].String())
return
}
type interfaceCalls interface {
Return1() string
}
type structCalls struct {
}
func (r structCalls) Return1() string {
return "Return1"
}
func (r *structCalls) Return2() string {
return "Return2"
}
func multiOutputRes() (int, string) {
return 42, "24"
}
func returnInterfaceCalls() interfaceCalls {
return structCalls{}
}
func methodOps() {
sprintf := makeSprintf("methodOps")
defer func() {
sprintf("from defer")
}()
defer sprintf("from defer 2")
var wg sync.WaitGroup
wg.Add(1)
go func() {
sprintf("from go")
wg.Done()
}()
wg.Wait()
i, s := multiOutputRes()
sprintf(strconv.Itoa(i))
var strct structCalls
strct.Return1()
strct.Return2()
intrfs := returnInterfaceCalls()
intrfs.Return1()
sprintf(strconv.Itoa(len(s)))
strconv.Itoa(binary.Size(4))
sprintf(binary.LittleEndian.AppendUint32(nil, 42))
if len(s) == 0 {
panic("unreachable")
}
sprintf(*unsafe.StringData(s))
thunkMethod1 := structCalls.Return1
sprintf(thunkMethod1(strct))
thunkMethod2 := (*structCalls).Return2
sprintf(thunkMethod2(&strct))
closureVar := "c " + s
anonFnc := func(n func(structCalls) string) string {
return n(structCalls{}) + "anon" + closureVar
}
sprintf(anonFnc(structCalls.Return1))
}
func chanOps() {
sprintf := makeSprintf("chanOps")
a := make(chan string)
b := make(chan string)
c := make(chan string)
d := make(chan string)
select {
case r1, ok := <-a:
sprintf(r1, ok)
case r2 := <-b:
sprintf(r2)
case <-c:
sprintf("r3")
case d <- "test":
sprintf("d triggered")
default:
sprintf("default")
}
e := make(chan string, 1)
e <- "hi"
sprintf(<-e)
close(a)
val, ok := <-a
sprintf(val, ok)
return
}
func flowOps() {
sprintf := makeSprintf("flowOps")
i := 1
if return42()%2 == 0 {
sprintf("a")
i++
} else {
sprintf("b")
}
sprintf(i)
switch return42() {
case 1:
sprintf("1")
case 2:
sprintf("2")
case 3:
sprintf("3")
case 42:
sprintf("42")
}
}
type interfaceB interface {
}
type testStruct struct {
A, B int
}
func typeOps() {
sprintf := makeSprintf("typeOps")
// *ssa.ChangeType
var interA interfaceCalls
sprintf(interA)
// *ssa.ChangeInterface
var interB interfaceB = struct{}{}
var inter0 interface{} = interB
sprintf(inter0)
// *ssa.Convert
var f float64 = 1.0
sprintf(int(f))
casted, ok := inter0.(interfaceB)
sprintf(casted, ok)
casted2 := inter0.(interfaceB)
sprintf(casted2)
strc := testStruct{return42(), return42() + 2}
strc.B += strc.A
sprintf(strc)
// Access to unexported structure
discard := io.Discard
if return42() == 0 {
sprintf(discard) // Trigger phi block
}
_, _ = discard.Write([]byte("test"))
}`
// TestConvert runs mainSrc as written, then converts every function of it
// from ssa back to ast, prints the result to a new file, runs that file and
// requires both runs to produce the same output.
func TestConvert(t *testing.T) {
	// runGoFile executes "go run" on f and returns its combined output,
	// failing the test when the command does not succeed.
	runGoFile := func(f string) string {
		t.Helper()
		cmd := exec.Command("go", "run", f)
		out, err := cmd.CombinedOutput()
		if err != nil {
			t.Fatalf("compile failed: %v\n%s", err, string(out))
		}
		return string(out)
	}

	testFile := filepath.Join(t.TempDir(), "convert.go")
	// A source file does not need to be executable.
	if err := os.WriteFile(testFile, []byte(mainSrc), 0o666); err != nil {
		t.Fatal(err)
	}

	originalOut := runGoFile(testFile)

	file, fset, _, _ := mustParseAndTypeCheckFile(mainSrc)
	ssaPkg, _, err := ssautil.BuildPackage(&types.Config{Importer: importer.Default()}, fset, types.NewPackage("test/main", ""), []*ast.File{file}, 0)
	if err != nil {
		t.Fatal(err)
	}

	// Replace every function declaration with its converted counterpart.
	for fIdx, decl := range file.Decls {
		funcDecl, ok := decl.(*ast.FuncDecl)
		if !ok {
			continue
		}
		path, _ := astutil.PathEnclosingInterval(file, funcDecl.Pos(), funcDecl.Pos())
		ssaFunc := ssa.EnclosingFunction(ssaPkg, path)
		if ssaFunc == nil {
			t.Fatalf("function %s exists in ast but not found in ssa", funcDecl.Name.Name)
		}
		astFunc, err := Convert(ssaFunc, DefaultConfig())
		if err != nil {
			t.Fatal(err)
		}
		file.Decls[fIdx] = astFunc
	}

	convertedFile := filepath.Join(t.TempDir(), "main.go")
	f, err := os.Create(convertedFile)
	if err != nil {
		t.Fatal(err)
	}
	// Always close the file, even when printing fails, and do not ignore
	// the close error: a failed close can mean the source was not fully
	// written, which would surface later as a confusing compile failure.
	printErr := printer.Fprint(f, fset, file)
	closeErr := f.Close()
	if printErr != nil {
		t.Fatal(printErr)
	}
	if closeErr != nil {
		t.Fatal(closeErr)
	}

	convertedOut := runGoFile(convertedFile)

	if convertedOut != originalOut {
		t.Fatalf("Output not equals:\n\n%s\n\n%s", originalOut, convertedOut)
	}
}

View File

@@ -0,0 +1,72 @@
package ssa2ast
import (
"go/ast"
"go/importer"
"go/parser"
"go/token"
"go/types"
"github.com/google/go-cmp/cmp/cmpopts"
)
// astCmpOpt is the comparison option used when diffing AST nodes in tests.
// It is built with cmpopts.IgnoreTypes from a token.Pos value and an
// *ast.Object value, so values of those two types do not take part in the
// comparison.
var astCmpOpt = cmpopts.IgnoreTypes(token.NoPos, &ast.Object{})
// findStruct walks file and returns the name identifier and the struct type
// of the first type declaration called structName whose type is a struct.
// It panics when no such declaration exists.
func findStruct(file *ast.File, structName string) (name *ast.Ident, structType *ast.StructType) {
	ast.Inspect(file, func(n ast.Node) bool {
		if structType != nil {
			return false // already found, stop descending
		}
		if spec, isSpec := n.(*ast.TypeSpec); isSpec && spec.Name != nil && spec.Name.Name == structName {
			if st, isStruct := spec.Type.(*ast.StructType); isStruct {
				name, structType = spec.Name, st
			}
		}
		return true
	})

	if structType == nil {
		panic(structName + " not found")
	}
	return name, structType
}
// findFunc returns the first top-level function or method declaration of
// file whose name is funcName. It panics when there is none.
func findFunc(file *ast.File, funcName string) *ast.FuncDecl {
	for i := range file.Decls {
		fn, isFunc := file.Decls[i].(*ast.FuncDecl)
		if !isFunc || fn.Name.Name != funcName {
			continue
		}
		return fn
	}
	panic(funcName + " not found")
}
// mustParseAndTypeCheckFile parses src as the file "a.go" and type-checks it
// as package path "test/main". It returns the parsed file, the file set it
// was parsed into, the collected type information and the checked package.
// Any parse or type error makes it panic.
func mustParseAndTypeCheckFile(src string) (*ast.File, *token.FileSet, *types.Info, *types.Package) {
	fset := token.NewFileSet()
	file, parseErr := parser.ParseFile(fset, "a.go", src, 0)
	if parseErr != nil {
		panic(parseErr)
	}

	// Ask the checker to record every kind of information listed here.
	info := new(types.Info)
	info.Types = make(map[ast.Expr]types.TypeAndValue)
	info.Defs = make(map[*ast.Ident]types.Object)
	info.Uses = make(map[*ast.Ident]types.Object)
	info.Instances = make(map[*ast.Ident]types.Instance)
	info.Implicits = make(map[ast.Node]types.Object)
	info.Scopes = make(map[ast.Node]*types.Scope)
	info.Selections = make(map[*ast.SelectorExpr]*types.Selection)

	checker := types.Config{Importer: importer.Default()}
	pkg, checkErr := checker.Check("test/main", fset, []*ast.File{file}, info)
	if checkErr != nil {
		panic(checkErr)
	}
	return file, fset, info, pkg
}

View File

@@ -0,0 +1,176 @@
package ssa2ast
import (
"go/ast"
"go/token"
"go/types"
)
// makeMapIteratorPolyfill builds a function literal that turns a map of type
// mapType into a "next" function, and returns it together with the type of
// that next function: func() (bool, <key type>, <value type>).
// An error from converting the key or the value type is returned as is.
//
// The literal corresponds to this snippet:
/*
	func(m map[<key type>]<value type>) func() (bool, <key type>, <value type>) {
		keys := make([]<key type>, 0, len(m))
		for k := range m {
			keys = append(keys, k)
		}
		i := 0
		return func() (ok bool, k <key type>, r <value type>) {
			if i < len(keys) {
				k = keys[i]
				ok, r = true, m[k]
				i++
			}
			return
		}
	}
*/
func makeMapIteratorPolyfill(tc *typeConverter, mapType *types.Map) (ast.Expr, types.Type, error) {
	keyTypeExpr, err := tc.Convert(mapType.Key())
	if err != nil {
		return nil, nil, err
	}
	valueTypeExpr, err := tc.Convert(mapType.Elem())
	if err != nil {
		return nil, nil, err
	}

	unnamedVar := func(typ types.Type) *types.Var {
		return types.NewVar(token.NoPos, nil, "", typ)
	}
	nextType := types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(
		unnamedVar(types.Typ[types.Bool]),
		unnamedVar(mapType.Key()),
		unnamedVar(mapType.Elem()),
	), false)

	// Small constructors; every call yields a fresh node.
	ident := func(name string) *ast.Ident { return &ast.Ident{Name: name} }
	intZero := func() *ast.BasicLit { return &ast.BasicLit{Kind: token.INT, Value: "0"} }
	call := func(fun string, args ...ast.Expr) *ast.CallExpr {
		return &ast.CallExpr{Fun: ident(fun), Args: args}
	}

	// keys := make([]<key type>, 0, len(m))
	allocKeys := &ast.AssignStmt{
		Lhs: []ast.Expr{ident("keys")},
		Tok: token.DEFINE,
		Rhs: []ast.Expr{
			call("make", &ast.ArrayType{Elt: keyTypeExpr}, intZero(), call("len", ident("m"))),
		},
	}

	// for k := range m { keys = append(keys, k) }
	collectKeys := &ast.RangeStmt{
		Key: ident("k"),
		Tok: token.DEFINE,
		X:   ident("m"),
		Body: &ast.BlockStmt{List: []ast.Stmt{
			&ast.AssignStmt{
				Lhs: []ast.Expr{ident("keys")},
				Tok: token.ASSIGN,
				Rhs: []ast.Expr{call("append", ident("keys"), ident("k"))},
			},
		}},
	}

	// i := 0
	initIndex := &ast.AssignStmt{
		Lhs: []ast.Expr{ident("i")},
		Tok: token.DEFINE,
		Rhs: []ast.Expr{intZero()},
	}

	// if i < len(keys) { k = keys[i]; ok, r = true, m[k]; i++ }
	advance := &ast.IfStmt{
		Cond: &ast.BinaryExpr{
			X:  ident("i"),
			Op: token.LSS,
			Y:  call("len", ident("keys")),
		},
		Body: &ast.BlockStmt{List: []ast.Stmt{
			&ast.AssignStmt{
				Lhs: []ast.Expr{ident("k")},
				Tok: token.ASSIGN,
				Rhs: []ast.Expr{&ast.IndexExpr{X: ident("keys"), Index: ident("i")}},
			},
			&ast.AssignStmt{
				Lhs: []ast.Expr{ident("ok"), ident("r")},
				Tok: token.ASSIGN,
				Rhs: []ast.Expr{
					ident("true"),
					&ast.IndexExpr{X: ident("m"), Index: ident("k")},
				},
			},
			&ast.IncDecStmt{X: ident("i"), Tok: token.INC},
		}},
	}

	// func() (ok bool, k <key type>, r <value type>) { ...; return }
	nextFunc := &ast.FuncLit{
		Type: &ast.FuncType{
			Params: &ast.FieldList{},
			Results: &ast.FieldList{List: []*ast.Field{
				{Names: []*ast.Ident{ident("ok")}, Type: ident("bool")},
				{Names: []*ast.Ident{ident("k")}, Type: keyTypeExpr},
				{Names: []*ast.Ident{ident("r")}, Type: valueTypeExpr},
			}},
		},
		Body: &ast.BlockStmt{List: []ast.Stmt{advance, &ast.ReturnStmt{}}},
	}

	// func(m map[<key type>]<value type>) func() (bool, <key type>, <value type>)
	outerType := &ast.FuncType{
		Params: &ast.FieldList{List: []*ast.Field{{
			Names: []*ast.Ident{ident("m")},
			Type:  &ast.MapType{Key: keyTypeExpr, Value: valueTypeExpr},
		}}},
		Results: &ast.FieldList{List: []*ast.Field{{
			Type: &ast.FuncType{
				Params: &ast.FieldList{},
				Results: &ast.FieldList{List: []*ast.Field{
					{Type: ident("bool")},
					{Type: keyTypeExpr},
					{Type: valueTypeExpr},
				}},
			},
		}}},
	}

	return &ast.FuncLit{
		Type: outerType,
		Body: &ast.BlockStmt{List: []ast.Stmt{
			allocKeys,
			collectKeys,
			initIndex,
			&ast.ReturnStmt{Results: []ast.Expr{nextFunc}},
		}},
	}, nextType, nil
}

249
internal/ssa2ast/type.go Normal file
View File

@@ -0,0 +1,249 @@
package ssa2ast
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"reflect"
"strconv"
)
// typeConverter converts go/types types into the AST expressions that spell
// them in source code.
type typeConverter struct {
	// resolver maps a package to the identifier used to qualify names from
	// it; Convert treats a nil result as "no qualifier".
	resolver ImportNameResolver
}
// Convert returns the AST expression that denotes type t.
//
// Composite types are converted recursively and the first error encountered
// is returned. A type that is not handled by any case yields an error
// wrapping ErrUnsupported.
func (tc *typeConverter) Convert(t types.Type) (ast.Expr, error) {
	switch typ := t.(type) {
	case *types.Array:
		// [Len]Elem
		eltExpr, err := tc.Convert(typ.Elem())
		if err != nil {
			return nil, err
		}
		return &ast.ArrayType{
			Len: &ast.BasicLit{
				Kind: token.INT,
				Value: strconv.FormatInt(typ.Len(), 10),
			},
			Elt: eltExpr,
		}, nil
	case *types.Basic:
		// unsafe.Pointer is the only basic type that needs a package qualifier.
		if typ.Kind() == types.UnsafePointer {
			unsafePkgIdent := tc.resolver(types.Unsafe)
			if unsafePkgIdent == nil {
				return nil, fmt.Errorf("cannot resolve unsafe package")
			}
			return &ast.SelectorExpr{X: unsafePkgIdent, Sel: ast.NewIdent("Pointer")}, nil
		}
		return ast.NewIdent(typ.Name()), nil
	case *types.Chan:
		chanValueExpr, err := tc.Convert(typ.Elem())
		if err != nil {
			return nil, err
		}
		chanExpr := &ast.ChanType{Value: chanValueExpr}
		// Translate the channel direction into the ast flags.
		switch typ.Dir() {
		case types.SendRecv:
			chanExpr.Dir = ast.SEND | ast.RECV
		case types.RecvOnly:
			chanExpr.Dir = ast.RECV
		case types.SendOnly:
			chanExpr.Dir = ast.SEND
		}
		return chanExpr, nil
	case *types.Interface:
		// Embedded types are emitted first, followed by the explicit methods.
		methods := &ast.FieldList{}
		for i := 0; i < typ.NumEmbeddeds(); i++ {
			embeddedType := typ.EmbeddedType(i)
			embeddedExpr, err := tc.Convert(embeddedType)
			if err != nil {
				return nil, err
			}
			methods.List = append(methods.List, &ast.Field{Type: embeddedExpr})
		}
		for i := 0; i < typ.NumExplicitMethods(); i++ {
			method := typ.ExplicitMethod(i)
			methodSig, err := tc.Convert(method.Type())
			if err != nil {
				return nil, err
			}
			methods.List = append(methods.List, &ast.Field{
				Names: []*ast.Ident{ast.NewIdent(method.Name())},
				Type: methodSig,
			})
		}
		return &ast.InterfaceType{Methods: methods}, nil
	case *types.Map:
		// map[Key]Elem
		keyExpr, err := tc.Convert(typ.Key())
		if err != nil {
			return nil, err
		}
		valueExpr, err := tc.Convert(typ.Elem())
		if err != nil {
			return nil, err
		}
		return &ast.MapType{Key: keyExpr, Value: valueExpr}, nil
	case *types.Named:
		obj := typ.Obj()

		// TODO: rewrite struct inlining without reflection hack
		// A named type whose parent scope has its unexported isFunc field set
		// is replaced by its underlying type instead of being referenced by name.
		if parent := obj.Parent(); parent != nil {
			isFuncScope := reflect.ValueOf(parent).Elem().FieldByName("isFunc")
			if isFuncScope.Bool() {
				return tc.Convert(obj.Type().Underlying())
			}
		}

		var namedExpr ast.Expr
		if pkgIdent := tc.resolver(obj.Pkg()); pkgIdent != nil {
			// reference to an unexported named type is emulated through a new interface with explicitly declared methods
			if !token.IsExported(obj.Name()) {
				var methods []*types.Func
				for i := 0; i < typ.NumMethods(); i++ {
					method := typ.Method(i)
					if token.IsExported(method.Name()) {
						methods = append(methods, method)
					}
				}

				fakeInterface := types.NewInterfaceType(methods, nil)
				return tc.Convert(fakeInterface)
			}
			namedExpr = &ast.SelectorExpr{X: pkgIdent, Sel: ast.NewIdent(obj.Name())}
		} else {
			// No qualifier: the type is referenced by its bare name.
			namedExpr = ast.NewIdent(obj.Name())
		}

		// Despite its name, typeParams holds the type arguments of the named type.
		typeParams := typ.TypeArgs()
		if typeParams == nil || typeParams.Len() == 0 {
			return namedExpr, nil
		}
		// A single type argument is an IndexExpr, several are an IndexListExpr.
		if typeParams.Len() == 1 {
			typeParamExpr, err := tc.Convert(typeParams.At(0))
			if err != nil {
				return nil, err
			}
			return &ast.IndexExpr{X: namedExpr, Index: typeParamExpr}, nil
		}
		genericExpr := &ast.IndexListExpr{X: namedExpr}
		for i := 0; i < typeParams.Len(); i++ {
			typeArgs := typeParams.At(i)
			typeParamExpr, err := tc.Convert(typeArgs)
			if err != nil {
				return nil, err
			}
			genericExpr.Indices = append(genericExpr.Indices, typeParamExpr)
		}
		return genericExpr, nil
	case *types.Pointer:
		// *Elem
		expr, err := tc.Convert(typ.Elem())
		if err != nil {
			return nil, err
		}
		return &ast.StarExpr{X: expr}, nil
	case *types.Signature:
		// Params is always set; Results and TypeParams only when present.
		funcSigExpr := &ast.FuncType{Params: &ast.FieldList{}}
		if sigParams := typ.Params(); sigParams != nil {
			for i := 0; i < sigParams.Len(); i++ {
				param := sigParams.At(i)

				var paramType ast.Expr
				if typ.Variadic() && i == sigParams.Len()-1 {
					// The last parameter of a variadic signature is emitted
					// as "...Elem"; its type is asserted to be a slice.
					slice := param.Type().(*types.Slice)

					eltExpr, err := tc.Convert(slice.Elem())
					if err != nil {
						return nil, err
					}
					paramType = &ast.Ellipsis{Elt: eltExpr}
				} else {
					paramExpr, err := tc.Convert(param.Type())
					if err != nil {
						return nil, err
					}
					paramType = paramExpr
				}

				f := &ast.Field{Type: paramType}
				if name := param.Name(); name != "" {
					f.Names = []*ast.Ident{ast.NewIdent(name)}
				}
				funcSigExpr.Params.List = append(funcSigExpr.Params.List, f)
			}
		}
		if sigResults := typ.Results(); sigResults != nil {
			funcSigExpr.Results = &ast.FieldList{}
			for i := 0; i < sigResults.Len(); i++ {
				result := sigResults.At(i)
				resultExpr, err := tc.Convert(result.Type())
				if err != nil {
					return nil, err
				}

				f := &ast.Field{Type: resultExpr}
				if name := result.Name(); name != "" {
					f.Names = []*ast.Ident{ast.NewIdent(name)}
				}
				funcSigExpr.Results.List = append(funcSigExpr.Results.List, f)
			}
		}
		if typeParams := typ.TypeParams(); typeParams != nil {
			funcSigExpr.TypeParams = &ast.FieldList{}
			for i := 0; i < typeParams.Len(); i++ {
				typeParam := typeParams.At(i)
				// The constraint is spelled through its underlying type.
				resultExpr, err := tc.Convert(typeParam.Constraint().Underlying())
				if err != nil {
					return nil, err
				}
				f := &ast.Field{Type: resultExpr, Names: []*ast.Ident{ast.NewIdent(typeParam.Obj().Name())}}
				funcSigExpr.TypeParams.List = append(funcSigExpr.TypeParams.List, f)
			}
		}
		return funcSigExpr, nil
	case *types.Slice:
		// []Elem is an ArrayType without a length.
		eltExpr, err := tc.Convert(typ.Elem())
		if err != nil {
			return nil, err
		}
		return &ast.ArrayType{Elt: eltExpr}, nil
	case *types.Struct:
		fieldList := &ast.FieldList{}
		for i := 0; i < typ.NumFields(); i++ {
			f := typ.Field(i)
			fieldExpr, err := tc.Convert(f.Type())
			if err != nil {
				return nil, err
			}
			field := &ast.Field{Type: fieldExpr}
			// Fields reported as anonymous are emitted without a name.
			if !f.Anonymous() {
				field.Names = []*ast.Ident{ast.NewIdent(f.Name())}
			}
			// The tag is re-quoted into a string literal.
			if tag := typ.Tag(i); len(tag) > 0 {
				field.Tag = &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", tag)}
			}
			fieldList.List = append(fieldList.List, field)
		}
		return &ast.StructType{Fields: fieldList}, nil
	case *types.TypeParam:
		return ast.NewIdent(typ.Obj().Name()), nil
	case *types.Union:
		// Terms are chained left to right as "a | b | c"; a term with a
		// tilde is wrapped into "~term". A union without terms yields nil.
		var unionExpr ast.Expr
		for i := 0; i < typ.Len(); i++ {
			term := typ.Term(i)
			expr, err := tc.Convert(term.Type())
			if err != nil {
				return nil, err
			}
			if term.Tilde() {
				expr = &ast.UnaryExpr{Op: token.TILDE, X: expr}
			}
			if unionExpr == nil {
				unionExpr = expr
			} else {
				unionExpr = &ast.BinaryExpr{X: unionExpr, Op: token.OR, Y: expr}
			}
		}
		return unionExpr, nil
	default:
		return nil, fmt.Errorf("type %v: %w", typ, ErrUnsupported)
	}
}

View File

@@ -0,0 +1,108 @@
package ssa2ast
import (
"go/ast"
"testing"
"github.com/google/go-cmp/cmp"
)
// typesSrc is the source parsed by TestTypeToExpr. Its exampleStruct has
// fields covering the kinds of types handled by typeConverter.Convert.
const typesSrc = `package main
import (
"io"
"time"
)
type localNamed bool
type embedStruct struct {
int
}
type genericStruct[K comparable, V int64 | float64] struct {
int
}
type exampleStruct struct {
embedStruct
// *types.Array
array [3]int
array2 [0]int
// *types.Basic
bool // anonymous
string string "test:\"tag\""
int int
int8 int8
int16 int16
int32 int32
int64 int64
uint uint
uint8 uint8
uint16 uint16
uint32 uint32
uint64 uint64
uintptr uintptr
byte byte
rune rune
float32 float32
float64 float64
complex64 complex64
complex128 complex128
// *types.Chan
chanSendRecv chan struct{}
chanRecv <-chan struct{}
chanSend chan<- struct{}
// *types.Interface
interface1 interface{}
interface2 interface{ io.Reader }
interface3 interface{ Dummy(int) bool }
interface4 interface {
io.Reader
io.ByteReader
Dummy(int) bool
}
// *types.Map
strMap map[string]string
// *types.Named
localNamed localNamed
importedNamed time.Month
// *types.Pointer
pointer1 *string
pointer2 **string
// *types.Signature
func1 func(int, int) int
func2 func(a int, b int, varargs ...struct{ string }) (res int)
// *types.Slice
slice1 []int
slice2 [][]int
// generics
generic genericStruct[genericStruct[genericStruct[bool, int64], int64], int64]
}
`
// TestTypeToExpr converts the type-checked underlying type of exampleStruct
// from typesSrc back into an AST expression and requires it to match the
// parsed struct declaration.
func TestTypeToExpr(t *testing.T) {
	file, _, typesInfo, _ := mustParseAndTypeCheckFile(typesSrc)
	ident, want := findStruct(file, "exampleStruct")
	structObj := typesInfo.Defs[ident]

	converter := &typeConverter{resolver: defaultImportNameResolver}
	converted, err := converter.Convert(structObj.Type().Underlying())
	if err != nil {
		t.Fatal(err)
	}

	got := converted.(*ast.StructType)
	diff := cmp.Diff(want, got, astCmpOpt)
	if diff != "" {
		t.Fatalf("struct not equals: %s", diff)
	}
}