feat: add Go AST libary

This commit is contained in:
zhuyasen
2025-04-06 19:41:41 +08:00
parent 358e7939db
commit 04d9f87ee3
9 changed files with 2544 additions and 0 deletions

69
pkg/goast/README.md Normal file
View File

@@ -0,0 +1,69 @@
## goast
`goast` is a library for parsing Go code and extracting information, it supports merging two Go files into one.
## Example of use
### Parse Go code and extract information
```go
package main
import (
"fmt"
"github.com/go-dev-frame/sponge/pkg/goast"
)
func main() {
src := []byte(`package main
import (
"fmt"
)
func main() {
fmt.Println("Hello, world!")
}
`)
// Case 1: Parse Go code and extract information
{
astInfos, err := goast.ParseGoCode("main.go", src)
}
// Case 2: Parse file and extract information
{
astInfos, err := goast.ParseFile("main.go")
}
}
```
### Merge two Go files into one
```go
package main
import (
"fmt"
"github.com/go-dev-frame/sponge/pkg/goast"
)
func main() {
const (
srcFile = "data/src.go.code"
genFile = "data/gen.go.code"
)
// Case 1: without covering the same function
{
codeAst, err := goast.MergeGoFile(srcFile, genFile)
fmt.Println(codeAst.Code)
}
// Case 2: with covering the same function
{
codeAst, err := goast.MergeGoFile(srcFile, genFile, goast.WithCoverSameFunc())
fmt.Println(codeAst.Code)
}
}
```

814
pkg/goast/ast.go Normal file
View File

@@ -0,0 +1,814 @@
// Package goast is a library for parsing Go code and extracting information from it.
package goast
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
)
const (
// ast types
PackageType = "package"
ImportType = "import"
ConstType = "const"
VarType = "var"
FuncType = "func"
TypeType = "type"
// for TypeType
StructType = "struct"
InterfaceType = "interface"
ArrayType = "array"
MapType = "map"
ChanType = "chan"
)
// AstInfo Go code block information
type AstInfo struct {
// Type is the type of the code block, such as "func", "type", "const", "var", "import", "package".
Type string
// Names is the name of the code block, such as "func Name", "type Names", "const Names", "var Names", "import Paths".
// If Type is "func", a standalone function without a receiver has a single name.
// If the function is a method belonging to a struct, it has two names: the first
// represents the function name, and the second represents the struct name.
Names []string
Comment string
Body string
}
func (a *AstInfo) IsPackageType() bool {
return a.Type == PackageType
}
func (a *AstInfo) IsImportType() bool {
return a.Type == ImportType
}
func (a *AstInfo) IsConstType() bool {
return a.Type == ConstType
}
func (a *AstInfo) IsVarType() bool {
return a.Type == VarType
}
func (a *AstInfo) IsTypeType() bool {
return a.Type == TypeType
}
func (a *AstInfo) IsFuncType() bool {
return a.Type == FuncType
}
func (a *AstInfo) GetName() string {
return strings.Join(a.Names, ",")
}
// ParseFile parses a go file and returns a list of AstInfo
func ParseFile(goFilePath string) ([]*AstInfo, error) {
filename := filepath.Base(goFilePath)
data, err := os.ReadFile(goFilePath)
if err != nil {
return nil, err
}
return ParseGoCode(filename, data)
}
// ParseGoCode parses a go code and returns a list of AstInfo
func ParseGoCode(filename string, data []byte) ([]*AstInfo, error) {
src := string(data)
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
if err != nil {
return nil, err
}
var astInfos []*AstInfo
pkgNames, pkgComment, pkgBody := getPackageCode(fset, file, src)
astInfos = append(astInfos, &AstInfo{Type: PackageType, Names: pkgNames, Comment: pkgComment, Body: pkgBody})
// traverse AST code blocks
ast.Inspect(file, func(n ast.Node) bool {
switch node := n.(type) {
case *ast.FuncDecl:
receiverName, comment, body := getFuncDeclCode(fset, node, src)
names := []string{node.Name.Name}
if receiverName != "" {
names = append(names, receiverName)
}
astInfos = append(astInfos, &AstInfo{Type: FuncType, Names: names, Comment: comment, Body: body})
case *ast.GenDecl:
names, comment, body := getGenDeclCode(fset, node, src)
astInfos = append(astInfos, &AstInfo{Type: node.Tok.String(), Names: names, Comment: comment, Body: body})
//case *ast.BadDecl:
// code := getBadDeclCode(fset, node, src)
// println(code)
}
return true
})
return astInfos, nil
}
func getPackageCode(fset *token.FileSet, f *ast.File, src string) (names []string, comment string, body string) {
if f.Doc != nil {
var comments []string
for _, cmt := range f.Doc.List {
comments = append(comments, cmt.Text)
}
comment = strings.Join(comments, "\n")
}
packagePos := fset.Position(f.Package)
body = src[packagePos.Offset : packagePos.Offset+len("package "+f.Name.Name)]
return []string{f.Name.Name}, comment, body
}
func getFuncDeclCode(fset *token.FileSet, fn *ast.FuncDecl, src string) (receiverName string, comment string, body string) {
if fn == nil {
return "", "", ""
}
if fn.Recv != nil && len(fn.Recv.List) > 0 {
recvType := fn.Recv.List[0].Type
switch t := recvType.(type) {
case *ast.StarExpr:
if ident, ok := t.X.(*ast.Ident); ok {
receiverName = ident.Name
}
case *ast.Ident:
receiverName = t.Name
}
}
commentText := ""
if fn.Doc != nil {
var parts []string
for _, c := range fn.Doc.List {
parts = append(parts, strings.TrimSpace(c.Text))
}
commentText = strings.Join(parts, "\n")
}
start := fn.Type.Func // the starting position of the func keyword
end := fn.Body.Rbrace + 1 // end position of function body
return receiverName, commentText, getCodeFromPos(fset, start, end, src)
}
func getCodeFromPos(fset *token.FileSet, start, end token.Pos, src string) string {
file := fset.File(start)
if file == nil {
return ""
}
startOffset := file.Offset(start)
endOffset := file.Offset(end)
if startOffset < 0 || endOffset > len(src) || startOffset >= endOffset {
return ""
}
return src[startOffset:endOffset]
}
func getGenDeclCode(fset *token.FileSet, gen *ast.GenDecl, src string) (names []string, comment string, body string) {
if gen == nil {
return nil, "", ""
}
commentText := ""
if gen.Doc != nil {
var parts []string
for _, c := range gen.Doc.List {
parts = append(parts, strings.TrimSpace(c.Text))
}
commentText = strings.Join(parts, "\n")
}
start := gen.TokPos // keyword starting position
var end token.Pos
if gen.Rparen.IsValid() {
end = gen.Rparen + 1 // end position of parentheses
} else if len(gen.Specs) > 0 {
lastSpec := gen.Specs[len(gen.Specs)-1]
end = lastSpec.End() // end position of the last Spec
} else {
end = start + token.Pos(len(gen.Tok.String())) // in the case of keywords only
}
return getGenName(gen), commentText, getCodeFromPos(fset, start, end, src)
}
func getGenName(gen *ast.GenDecl) []string {
var names []string
switch gen.Tok {
case token.IMPORT:
for _, spec := range gen.Specs {
imp := spec.(*ast.ImportSpec)
names = append(names, imp.Path.Value)
}
case token.CONST:
for _, spec := range gen.Specs {
val := spec.(*ast.ValueSpec)
names = append(names, val.Names[0].Name)
}
case token.TYPE:
for _, spec := range gen.Specs {
typ := spec.(*ast.TypeSpec)
names = append(names, typ.Name.Name)
}
case token.VAR:
for _, spec := range gen.Specs {
val := spec.(*ast.ValueSpec)
names = append(names, val.Names[0].Name)
}
}
return names
}
// -----------------------------------------------------------------------------------
func adaptPackage(src string) string {
if len(src) > 50 {
if strings.Contains(src[:50], "package ") {
return src
}
}
if strings.Contains(src, "\npackage ") {
return src
}
return "package parse\n\n" + src
}
// nolint
func parseBody(body string) (*token.FileSet, *ast.File, string, error) {
src := adaptPackage(body)
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
if err != nil {
f, err = parser.ParseFile(fset, "", "package parse\n\n"+src, parser.ParseComments)
if err != nil {
return nil, nil, "", err
}
}
return fset, f, src, nil
}
type ImportInfo struct {
Path string
Alias string
Comment string
Body string
}
// ParseImportGroup parse import group from source code
func ParseImportGroup(body string) ([]*ImportInfo, error) {
fset, f, src, err := parseBody(body)
if err != nil {
return nil, err
}
var srcLines = strings.Split(src, "\n")
var imports []*ImportInfo
for _, decl := range f.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.IMPORT {
continue
}
for _, spec := range genDecl.Specs {
importSpec := spec.(*ast.ImportSpec)
// get path
path := strings.Trim(importSpec.Path.Value, `"`)
// get alias
var alias string
if importSpec.Name != nil {
alias = importSpec.Name.Name
}
// get comment doc
var comment string
if importSpec.Doc != nil {
comment = getSrcContent(srcLines, fset.Position(importSpec.Doc.List[0].Pos()).Line,
fset.Position(importSpec.Doc.List[len(importSpec.Doc.List)-1].End()).Line)
}
// get source code of import path
code := getSrcContent(srcLines, fset.Position(importSpec.Pos()).Line, fset.Position(importSpec.End()).Line)
imports = append(imports, &ImportInfo{
Path: path,
Alias: alias,
Comment: comment,
Body: code,
})
}
}
return imports, nil
}
type ConstInfo struct {
Name string
Value string
Comment string
Body string
}
// ParseConstGroup parse const group from source code
func ParseConstGroup(body string) ([]*ConstInfo, error) {
fset, f, src, err := parseBody(body)
if err != nil {
return nil, err
}
var srcLines = strings.Split(src, "\n")
var consts []*ConstInfo
for _, decl := range f.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.CONST {
continue
}
singleComment := ""
if genDecl.Doc != nil {
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
}
for _, spec := range genDecl.Specs {
valueSpec := spec.(*ast.ValueSpec)
for i, name := range valueSpec.Names {
constName := name.Name
// get line content
var comment string
if valueSpec.Doc != nil {
comment = getSrcContent(srcLines, fset.Position(valueSpec.Doc.List[0].Pos()).Line,
fset.Position(valueSpec.Doc.List[len(valueSpec.Doc.List)-1].End()).Line)
}
if len(genDecl.Specs) == 1 && singleComment != "" && comment == "" {
comment = singleComment
}
// get code content
code := getSrcContent(srcLines, fset.Position(valueSpec.Pos()).Line, fset.Position(valueSpec.End()).Line)
// get value (if exists)
var constValue string
if i < len(valueSpec.Values) {
if basicLit, ok := valueSpec.Values[i].(*ast.BasicLit); ok {
constValue = basicLit.Value
}
}
consts = append(consts, &ConstInfo{
Name: constName,
Value: constValue,
Comment: comment,
Body: code,
})
}
}
}
return consts, nil
}
type VarInfo struct {
Name string
Value string
Comment string
Body string
}
// ParseVarGroup parse var group from source code
func ParseVarGroup(body string) ([]*VarInfo, error) {
fset, f, src, err := parseBody(body)
if err != nil {
return nil, err
}
var srcLines = strings.Split(src, "\n")
var vars []*VarInfo
for _, decl := range f.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.VAR {
continue
}
singleComment := ""
if genDecl.Doc != nil {
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
}
for _, spec := range genDecl.Specs {
valueSpec := spec.(*ast.ValueSpec)
for i, name := range valueSpec.Names {
varName := name.Name
// get comment
var comment string
if valueSpec.Doc != nil {
comment = getSrcContent(srcLines, fset.Position(valueSpec.Doc.List[0].Pos()).Line,
fset.Position(valueSpec.Doc.List[len(valueSpec.Doc.List)-1].End()).Line)
}
if len(genDecl.Specs) == 1 && singleComment != "" && comment == "" {
comment = singleComment
}
// get code content
code := getSrcContent(srcLines, fset.Position(valueSpec.Pos()).Line, fset.Position(valueSpec.End()).Line)
// get var value (if exists)
var varValue string
if i < len(valueSpec.Values) {
if basicLit, ok := valueSpec.Values[i].(*ast.BasicLit); ok {
varValue = basicLit.Value
}
}
vars = append(vars, &VarInfo{
Name: varName,
Value: varValue,
Comment: comment,
Body: code,
})
}
}
}
return vars, nil
}
type TypeInfo struct {
Type string
Name string
Comment string
Body string
IsIdent bool
}
// ParseTypeGroup parse type group from source code
func ParseTypeGroup(body string) ([]*TypeInfo, error) {
fset, f, src, err := parseBody(body)
if err != nil {
return nil, err
}
var srcLines = strings.Split(src, "\n")
var types []*TypeInfo
for _, decl := range f.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
singleComment := ""
if genDecl.Doc != nil {
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
}
for _, spec := range genDecl.Specs {
typeSpec := spec.(*ast.TypeSpec)
typeName := typeSpec.Name.Name
// get comment
var comment string
if typeSpec.Doc != nil {
comment = getSrcContent(srcLines, fset.Position(typeSpec.Doc.List[0].Pos()).Line,
fset.Position(typeSpec.Doc.List[len(typeSpec.Doc.List)-1].End()).Line)
}
if len(genDecl.Specs) == 1 && singleComment != "" && comment == "" {
comment = singleComment
}
// get code content
code := getSrcContent(srcLines, fset.Position(typeSpec.Pos()).Line, fset.Position(typeSpec.End()).Line)
// get type definition
var isIdent bool
var typeDef string
switch t := typeSpec.Type.(type) {
case *ast.StructType:
typeDef = StructType
case *ast.InterfaceType:
typeDef = InterfaceType
case *ast.FuncType:
typeDef = FuncType
case *ast.MapType:
typeDef = MapType
case *ast.ArrayType:
typeDef = ArrayType
case *ast.ChanType:
typeDef = ChanType
case *ast.Ident:
typeDef = t.Name
isIdent = true
default:
typeDef = fmt.Sprintf("%T", t)
}
types = append(types, &TypeInfo{
Type: typeDef,
Name: typeName,
Comment: comment,
Body: code,
IsIdent: isIdent,
})
}
}
return types, nil
}
type InterfaceInfo struct {
Name string
Comment string
MethodInfos []*MethodInfo
}
// ParseInterface parse interface group from source code
func ParseInterface(body string) ([]*InterfaceInfo, error) {
fset, f, src, err := parseBody(body)
if err != nil {
return nil, err
}
var srcLines = strings.Split(src, "\n")
var interfaceInfos []*InterfaceInfo
for _, decl := range f.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
singleComment := ""
if genDecl.Doc != nil {
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
}
var methodInfos []*MethodInfo
for _, spec := range genDecl.Specs {
typeSpec := spec.(*ast.TypeSpec)
interfaceType, ok := typeSpec.Type.(*ast.InterfaceType)
if !ok {
continue
}
interfaceName := typeSpec.Name.Name
// get interface comment
var interfaceComment string
if typeSpec.Doc != nil {
interfaceComment = getSrcContent(srcLines, fset.Position(typeSpec.Doc.List[0].Pos()).Line,
fset.Position(typeSpec.Doc.List[len(typeSpec.Doc.List)-1].End()).Line)
}
if len(genDecl.Specs) == 1 && singleComment != "" && interfaceComment == "" {
interfaceComment = singleComment
}
var isIdent bool
for _, method := range interfaceType.Methods.List {
// get method name
var methodName string
switch t := method.Type.(type) {
case *ast.FuncType:
if len(method.Names) > 0 {
methodName = method.Names[0].Name
}
case *ast.Ident: // embedded interface
methodName = t.Name
isIdent = true
default:
continue
}
// get method comment
var methodComment string
if method.Doc != nil {
methodComment = getSrcContent(srcLines, fset.Position(method.Doc.List[0].Pos()).Line,
fset.Position(method.Doc.List[len(method.Doc.List)-1].End()).Line)
}
// get method line content
code := getSrcContent(srcLines, fset.Position(method.Pos()).Line, fset.Position(method.End()).Line)
methodInfos = append(methodInfos, &MethodInfo{
Name: methodName,
Comment: methodComment,
Body: code,
ReceiverName: interfaceName,
IsIdent: isIdent,
})
}
interfaceInfos = append(interfaceInfos, &InterfaceInfo{
Name: interfaceName,
Comment: interfaceComment,
MethodInfos: methodInfos,
})
}
}
return interfaceInfos, nil
}
// MethodInfo method function info
type MethodInfo struct {
Name string
Comment string
Body string
ReceiverName string
IsIdent bool
}
// ParseStructMethods parse struct methods from ast infos
func ParseStructMethods(astInfos []*AstInfo) map[string][]*MethodInfo {
var m = make(map[string][]*MethodInfo) // map[structName][]*MethodInfo
for _, info := range astInfos {
if !info.IsFuncType() {
continue
}
if len(info.Names) == 2 {
funcName := info.Names[0]
structName := info.Names[1]
methodAst := &MethodInfo{
Name: funcName,
Comment: info.Comment,
Body: info.Body,
ReceiverName: structName,
}
if methodInfos, ok := m[structName]; !ok {
m[structName] = []*MethodInfo{methodAst}
} else {
m[structName] = append(methodInfos, methodAst)
}
}
}
return m
}
type StructInfo struct {
Name string
Comment string
Fields []*StructFieldInfo
}
type StructFieldInfo struct {
Name string
Type string
Comment string
Body string
}
// ParseStruct parse struct info from source code
func ParseStruct(body string) (map[string]*StructInfo, error) { //nolint
fset, f, src, err := parseBody(body)
if err != nil {
return nil, err
}
var srcLines = strings.Split(src, "\n")
var structInfos = make(map[string]*StructInfo)
for _, decl := range f.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
singleComment := ""
if genDecl.Doc != nil {
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
continue
}
structName := typeSpec.Name.Name
// get struct comment
var structComment string
if typeSpec.Doc != nil {
structComment = getSrcContent(srcLines, fset.Position(typeSpec.Doc.List[0].Pos()).Line,
fset.Position(typeSpec.Doc.List[len(typeSpec.Doc.List)-1].End()).Line)
}
if len(genDecl.Specs) == 1 && singleComment != "" && structComment == "" {
structComment = singleComment
}
var fields []*StructFieldInfo
for _, field := range structType.Fields.List {
var fieldNames []string
if len(field.Names) > 0 {
for _, name := range field.Names {
fieldNames = append(fieldNames, name.Name)
}
} else {
// 处理嵌入字段
switch x := field.Type.(type) {
case *ast.Ident:
fieldNames = append(fieldNames, x.Name)
case *ast.StarExpr:
if ident, ok := x.X.(*ast.Ident); ok {
fieldNames = append(fieldNames, "*"+ident.Name)
}
}
}
// get field name
var fieldName string
if len(fieldNames) > 0 {
fieldName = fieldNames[0]
}
// get comment
var comment string
if field.Doc != nil {
comment = getSrcContent(srcLines, fset.Position(field.Doc.List[0].Pos()).Line,
fset.Position(field.Doc.List[len(field.Doc.List)-1].End()).Line)
}
// get source code of field
code := getSrcContent(srcLines, fset.Position(field.Pos()).Line, fset.Position(field.End()).Line)
fields = append(fields, &StructFieldInfo{
Name: fieldName,
Type: getTypeString(field.Type),
Comment: comment,
Body: code,
})
}
structInfos[structName] = &StructInfo{
Name: structName,
Comment: structComment,
Fields: fields,
}
}
}
return structInfos, nil
}
func getTypeString(expr ast.Expr) string {
switch t := expr.(type) {
case *ast.Ident:
return t.Name
case *ast.ArrayType:
return "[]" + getTypeString(t.Elt)
case *ast.StructType:
return "struct"
case *ast.SelectorExpr:
return getTypeString(t.X) + "." + t.Sel.Name
case *ast.StarExpr:
return "*" + getTypeString(t.X)
case *ast.MapType:
return "map[" + getTypeString(t.Key) + "]" + getTypeString(t.Value)
case *ast.InterfaceType:
return "interface{}"
case *ast.ChanType:
return "chan " + getTypeString(t.Value)
default:
return "unknown"
}
}
func getSrcContent(srcLines []string, start, end int) string {
var srcContent string
l := len(srcLines)
if start < 1 || start > l || end < 1 || end > l {
return ""
}
if start == end {
srcContent = srcLines[start-1]
} else {
srcContent = strings.Join(srcLines[start-1:end], "\n")
}
return srcContent
}

341
pkg/goast/ast_test.go Normal file
View File

@@ -0,0 +1,341 @@
package goast
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseGoFile(t *testing.T) {
astInfos, err := ParseFile("ast.go")
assert.NoError(t, err)
assert.Greater(t, len(astInfos), 10)
}
func TestParseGoCode(t *testing.T) {
var src = `
package main
import "fmt"
import "strings"
import (
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
)
const (
pi = 3.14
language = "Go"
)
var (
version = "v1.0.0"
repo = "sponge"
)
type User struct {
Name string
Age int
}
func (u *User) SayHello() {
fmt.Println("Hello, my name is", u.Name)
}
func main() {
fmt.Println(pi)
fmt.Println(language)
fmt.Println(version)
user:=&User{Name:"Tom",Age:20}
fmt.Println(user.Name)
fmt.Println(user.Age)
user.SayHello()
}
`
astInfos, err := ParseGoCode("", []byte(src))
assert.NoError(t, err)
for _, info := range astInfos {
fmt.Printf(" %-20s: %s\n", "name", info.Names)
fmt.Printf(" %-20s: %s\n", "comment", info.Comment)
fmt.Printf(" %-20s: %s\n\n\n", "body", info.Body)
}
}
func TestParseImportGroup(t *testing.T) {
body := `
import (
"fmt"
"github.com/gin-gonic/gin"
//"github.com/spf13/viper"
apiV1 "yourModuleName/api/v1"
// api v2
apiV2 "yourModuleName/api/v2"
)
`
importInfos, err := ParseImportGroup(body)
if err != nil {
fmt.Println(err)
return
}
for _, ii := range importInfos {
fmt.Printf(" %-20s: %s\n", "name", ii.Path)
fmt.Printf(" %-20s: %s\n", "alias", ii.Alias)
fmt.Printf(" %-20s: %s\n", "comment", ii.Comment)
fmt.Printf(" %-20s: %s\n\n\n", "body", ii.Body)
}
}
func TestParseConstGroup(t *testing.T) {
body := `
// pi constant
const pi = 3.14
const (
// Version number
version = "v1.0.0"
)
const (
// Development language
language = "Go"
// database type
dbDriver = "mysql"
)
`
constInfos, err := ParseConstGroup(body)
if err != nil {
assert.NotNil(t, err)
return
}
for _, ci := range constInfos {
fmt.Printf(" %-20s: %s\n", "name", ci.Name)
fmt.Printf(" %-20s: %s\n", "value", ci.Value)
fmt.Printf(" %-20s: %s\n", "comment", ci.Comment)
fmt.Printf(" %-20s: %s\n\n\n", "body", ci.Body)
}
}
func TestParseVarGroup(t *testing.T) {
body := `
var (
// Version number
version = "v1.0.0"
// Author
author = "name"
// Repository
repo = "sponge"
// Function variable
f1 = func() {
fmt.Println("hello")
}
)
`
varInfos, err := ParseVarGroup(body)
if err != nil {
assert.NotNil(t, err)
return
}
for _, vi := range varInfos {
fmt.Printf(" %-20s: %s\n", "name", vi.Name)
fmt.Printf(" %-20s: %s\n", "value", vi.Value)
fmt.Printf(" %-20s: %s\n", "comment", vi.Comment)
fmt.Printf(" %-20s: %s\n\n\n", "body", vi.Body)
}
}
func TestParseTypeGroup(t *testing.T) {
body := `
type (
// Struct type
ts struct {
name string
}
// Function type
tfn func(name string) bool
// Interface type
iFace interface {}
// Channel type
ch chan int
// Map type
m map[string]bool
// Slice type
slice []int
)
`
typeInfos, err := ParseTypeGroup(body)
if err != nil {
assert.NotNil(t, err)
return
}
for _, ti := range typeInfos {
fmt.Printf(" %-20s: %s\n", "type", ti.Type)
fmt.Printf(" %-20s: %s\n", "name", ti.Name)
fmt.Printf(" %-20s: %s\n", "comment", ti.Comment)
fmt.Printf(" %-20s: %s\n\n\n", "body", ti.Body)
}
}
func TestParseInterface(t *testing.T) {
body := `
type GreeterDao interface {
// get by id
Create(ctx context.Context, table *model.Greeter) error
// delete by id
DeleteByID(ctx context.Context, id uint64) error
// update by id
UpdateByID(ctx context.Context, table *model.Greeter) error
UserExampleDao
}
type UserExampleDao interface {
// get by id
Create(ctx context.Context, table *model.UserExample) error
// update by id
UpdateByID(ctx context.Context, table *model.UserExample) error
}
`
interfaceInfos, err := ParseInterface(body)
if err != nil {
assert.NotNil(t, err)
return
}
for _, info := range interfaceInfos {
fmt.Printf("%-20s : %s\n", "name", info.Name)
fmt.Printf("%-20s : %s\n", "comment", info.Comment)
for _, mi := range info.MethodInfos {
fmt.Printf(" %-20s: %s\n", "method name", mi.Name)
fmt.Printf(" %-20s: %s\n", "comment", mi.Comment)
fmt.Printf(" %-20s: %s\n", "body", mi.Body)
fmt.Printf(" %-20v: %t\n\n\n", "embedded", mi.IsIdent)
}
}
}
func TestParseStructMethods(t *testing.T) {
src := `
package demo
type userHandler struct {
server userV1.UserServer
}
// Create a record
func (h *userHandler) Create(ctx context.Context, req *userV1.CreateUserRequest) (*userV1.CreateUserReply, error) {
return h.server.Create(ctx, req)
}
// DeleteByID delete a record by id
func (h *userHandler) DeleteByID(ctx context.Context, req *userV1.DeleteUserByIDRequest) (*userV1.DeleteUserByIDReply, error) {
return h.server.DeleteByID(ctx, req)
}
// UpdateByID update a record by id
func (h *userHandler) UpdateByID(ctx context.Context, req *userV1.UpdateUserByIDRequest) (*userV1.UpdateUserByIDReply, error) {
return h.server.UpdateByID(ctx, req)
}
type greeterHandler struct {
server greeterV1.GreeterServer
}
// Create a record
func (h *greeterHandler) Create(ctx context.Context, req *greeterV1.CreateGreeterRequest) (*greeterV1.CreateGreeterReply, error) {
return h.server.Create(ctx, req)
}
// DeleteByID delete a record by id
func (h *greeterHandler) DeleteByID(ctx context.Context, req *greeterV1.DeleteGreeterByIDRequest) (*greeterV1.DeleteGreeterByIDReply, error) {
return h.server.DeleteByID(ctx, req)
}
`
astInfos, err := ParseGoCode("", []byte(src))
if err != nil {
assert.NotNil(t, err)
return
}
methods := ParseStructMethods(astInfos)
for structName, methodInfos := range methods {
fmt.Printf("%-20s : %s\n", "name", structName)
for _, mi := range methodInfos {
fmt.Printf(" %-20s: %s\n", "method name", mi.Name)
fmt.Printf(" %-20s: %s\n", "comment", mi.Comment)
fmt.Printf(" %-20s: %s\n\n\n", "body", mi.Body)
}
}
}
func TestParseStruct(t *testing.T) {
body := `
package goast
// AstInfo is the information of a code block.
type AstInfo struct {
Kind string
// Names is the name of the code block, such as "func Name", "type Names", "const Names", "var Names", "import Paths".
// If Type is "func", a standalone function without a receiver has a single name.
// If the function is a method belonging to a struct, it has two names: the first
// represents the function name, and the second represents the struct name.
Names []string // todo add name
// User information
User struct {
Name string
Age string
}
// embedded struct
*Address
// embedded struct2
Address
reader interface {}
writer any
sayMap map[string]string
ch1 chan int
ch2 chan *Address
}
// Address is address
type Address struct {
State string
// Addr is address
Addr string
}
`
structInfos, err := ParseStruct(body)
if err != nil {
t.Error(err)
return
}
for name, structInfo := range structInfos {
fmt.Printf("%-20s : %s\n", "struct name", name)
fmt.Printf("%-20s : %s\n", "struct comment", structInfo.Comment)
for _, field := range structInfo.Fields {
fmt.Printf(" %-20s: %s\n", "name", field.Name)
fmt.Printf(" %-20s: %s\n", "type", field.Type)
fmt.Printf(" %-20s: %s\n", "comment", field.Comment)
fmt.Printf(" %-20s: %s\n\n", "body", field.Body)
}
fmt.Printf("\n\n\n")
}
}

105
pkg/goast/data/gen.go.code Normal file
View File

@@ -0,0 +1,105 @@
package data
import (
// this is a comment
"bytes"
)
import "fmt"
// pi is a constant
const pi = 3.14
// new const
const (
// MinRetries min retries
MinRetries = 1
// Timeout1 time out 1
Timeout1 = 60
// Timeout2 time out 2
Timeout2 = 60
Timeout = 5
)
// a global variable a
var a = 1
var b = 2
// c global variable c
var c = 3
type person struct {
email string
// age is the person's age
age int
}
// SayHello is a method
// sayHello is a method 2
func (p *person) SayHello() {
fmt.Println("Hello, my age is", p.age)
}
// GetEmail is a method
// getEmail is a method 2
func (p *person) GetEmail() {
fmt.Println("Hi, my email is", p.email)
}
type iSayer interface {
// Say is a method
Say()
// SayHello is a method
SayHello(name string)
}
type iSayer2 interface {
// Hi is a method
Hi()
}
type sayer interface {
Say()
Hi()
}
type (
fn1 func()
chan1 chan int
map1 map[string]int
slice1 []int
Speaker interface {
Say()
Hi()
}
)
var (
_ = func() {}
_ = chan1(nil)
)
var _ []int
var _ = fn1(func() {})
// Merge merges two slices
func Merge(slice [][]byte) {
fmt.Println(bytes.Join(slice, []byte("\n")))
}
func GetByID(id int) {
// do something
}
// Hi is a method
func (p *person) Hi() {
fmt.Println("Hi, my age is", p.age)
}
func init() {
fmt.Println("init")
}
func init() {
fmt.Printf("%s\n", "init2")
}

View File

@@ -0,0 +1,74 @@
package data
import (
"fmt"
"strings"
)
const e = 2.7
const (
MinRetries = 2
Timeout = 1
)
func join() {
fmt.Println(strings.Join([]string{"a", "b", "c"}, ", "))
}
// x is a variable
var x = 1
var (
a = 1
y = 2
)
type person struct {
name string
}
// Say is a method
func (p *person) Say() {
fmt.Println("Hello, my name is", p.name)
}
// Hi is a method
func (p *person) Hi() {
fmt.Println("Hi, my name is", p.name)
}
type iSayer interface {
// Say is a method
Say()
// Hi is a method
Hi()
}
type iSayer2 interface {
// Say is a method
Say()
}
type sayer interface{}
type (
fn1 func()
chan1 chan int
)
type _ func(name int) bool
var _ = fn1(func() {})
func Merge(slice [][]byte) {
}
func GetByID(id int) {
}
func init() {
fmt.Println("init")
}

192
pkg/goast/filter.go Normal file
View File

@@ -0,0 +1,192 @@
package goast
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
)
// FuncInfo represents function information
type FuncInfo struct {
Name string
Comment string
}
// ExtractComment extracts function comments in Go code
func (f FuncInfo) ExtractComment() string {
if f.Comment == "" {
return ""
}
comment := f.Comment
// regular matching `//` or `/* */` comments
lineComment := regexp.MustCompile(`(?m)^//\s?`)
blockComment := regexp.MustCompile(`(?m)/\*|\*/`)
// remove the `//` or `/* */` tags
comment = lineComment.ReplaceAllString(comment, "")
comment = blockComment.ReplaceAllString(comment, "")
// remove the space at the beginning of the line and split the line
lines := strings.Split(comment, "\n")
for i := range lines {
lines[i] = strings.TrimSpace(lines[i])
}
// output the comment string
commentStr := strings.Join(lines, "\n")
commentStr = strings.TrimSpace(commentStr)
commentStr = strings.TrimPrefix(commentStr, f.Name)
return strings.TrimSpace(commentStr)
}
// containsPanicCall determine if there is a panic("implement me"), or customized flag, e.g. panic("ai to do")
func containsPanicCall(fn *ast.FuncDecl, customFlag ...string) bool {
found := false
ast.Inspect(fn.Body, func(n ast.Node) bool {
ce, ok := n.(*ast.CallExpr)
if !ok {
return true
}
ident, ok := ce.Fun.(*ast.Ident)
if !ok || ident.Name != "panic" {
return true
}
if len(ce.Args) > 0 {
basicLit, ok := ce.Args[0].(*ast.BasicLit)
if ok && basicLit.Kind == token.STRING {
s, err := strconv.Unquote(basicLit.Value)
if err == nil {
if strings.HasPrefix(s, "implement me") {
found = true
return false // stop traversing if you find it.
}
for _, flag := range customFlag {
if strings.Contains(s, flag) {
found = true
return false // stop traversing if you find it.
}
}
}
}
}
return true
})
return found
}
// interval indicates an area to delete
type interval struct {
start token.Pos
end token.Pos
}
// FilterFuncCodeByFile filters out the code of functions that contain panic("implement me") or customized flag, e.g. panic("ai to do")
func FilterFuncCodeByFile(goFilePath string, customFlag ...string) ([]byte, []FuncInfo, error) {
filename := filepath.Base(goFilePath)
data, err := os.ReadFile(goFilePath)
if err != nil {
return nil, nil, err
}
return FilterFuncCode(filename, data, customFlag...)
}
// FilterFuncCode filters out the code of functions that contain panic("implement me") or customized flag, e.g. panic("ai to do")
func FilterFuncCode(filename string, data []byte, customFlag ...string) ([]byte, []FuncInfo, error) {
fset := token.NewFileSet()
// parse source code for comments
f, err := parser.ParseFile(fset, filename, string(data), parser.ParseComments)
if err != nil {
return nil, nil, err
}
// used to record the code interval corresponding to the function to be deleted (including its Doc comment)
var removeIntervals []interval
// used to collect function names and comment information that contain panic ("implementation me")
var panicFuncInfos []FuncInfo
var isMatch bool
// traverse declarations in the file, keeping only qualified function declarations
var decls []ast.Decl
for _, decl := range f.Decls {
funcDecl, ok := decl.(*ast.FuncDecl)
if !ok {
decls = append(decls, decl)
continue
}
// preserve if function name starts with New
if strings.HasPrefix(funcDecl.Name.Name, "New") {
decls = append(decls, decl)
continue
}
// if the function body is nil, it remains
if funcDecl.Body == nil {
decls = append(decls, decl)
continue
}
// if matches call panic("implement me") and has function comment, the function is retained
if containsPanicCall(funcDecl, customFlag...) {
commentText := ""
if funcDecl.Doc != nil {
var parts []string
for _, c := range funcDecl.Doc.List {
parts = append(parts, strings.TrimSpace(c.Text))
}
commentText = strings.Join(parts, "\n")
}
panicFuncInfos = append(panicFuncInfos, FuncInfo{
Name: funcDecl.Name.Name,
Comment: commentText,
})
if commentText != "" {
decls = append(decls, decl)
isMatch = true
}
continue
}
// delete other cases: record deletion interval
start := funcDecl.Pos()
if funcDecl.Doc != nil {
start = funcDecl.Doc.Pos()
}
removeIntervals = append(removeIntervals, interval{start: start, end: funcDecl.End()})
}
if !isMatch {
return nil, nil, fmt.Errorf("no function satisfies both conditions: 1. the function body contains the" +
" panic(\"implement me\") marker, and 2. the function includes a comment describing its functionality")
}
f.Decls = decls
// filter comment groups to remove comments that fall within the deletion interval
var newComments []*ast.CommentGroup
COMMENT:
for _, cg := range f.Comments {
for _, rem := range removeIntervals {
if cg.Pos() >= rem.start && cg.End() <= rem.end {
continue COMMENT
}
}
newComments = append(newComments, cg)
}
f.Comments = newComments
var buf bytes.Buffer
if err = printer.Fprint(&buf, fset, f); err != nil {
return nil, nil, err
}
return buf.Bytes(), panicFuncInfos, nil
}

40
pkg/goast/filter_test.go Normal file
View File

@@ -0,0 +1,40 @@
package goast
import (
"testing"
"github.com/stretchr/testify/assert"
)
// This is a demo function for testing, default panic message is "implement me"
func demoFn1() {
panic("implement me")
}
// This is a demo function for testing, default panic message is "ai todo"
func demoFn2() {
panic("implement me")
}
// This is a demo function for testing, default panic message is "foobar"
func demoFn3() {
panic("implement me")
}
func TestFilterFuncCodeByFile(t *testing.T) {
code, infos, err := FilterFuncCodeByFile("filter_test.go")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.Equal(t, 3, len(infos))
assert.Equal(t, "demoFn1", infos[0].Name)
assert.Contains(t, infos[0].ExtractComment(), `"implement me"`)
//t.Log(code)
code, infos, err = FilterFuncCodeByFile("filter_test.go", "ai todo", "foobar")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.Equal(t, 3, len(infos))
assert.Equal(t, "demoFn3", infos[2].Name)
assert.Contains(t, infos[2].ExtractComment(), `"foobar"`)
//t.Log(code)
}

799
pkg/goast/merge.go Normal file
View File

@@ -0,0 +1,799 @@
package goast
import (
"fmt"
"go/format"
"os"
"strings"
)
type CodeAstOption func(*CodeAst)
func defaultClientOptions() *CodeAst {
return &CodeAst{
ignoreFuncNameMap: make(map[string]struct{}),
}
}
func (a *CodeAst) apply(opts ...CodeAstOption) {
for _, opt := range opts {
opt(a)
}
}
// WithCoverSameFunc sets cover same function in the merged code
func WithCoverSameFunc() CodeAstOption {
return func(a *CodeAst) {
a.isCoverSameFunc = true
}
}
// WithIgnoreMergeFunc sets ignore to merge the same function name in the two code
func WithIgnoreMergeFunc(funcName ...string) CodeAstOption {
return func(a *CodeAst) {
for _, name := range funcName {
a.ignoreFuncNameMap[name] = struct{}{}
}
}
}
// CodeAst is the struct for code
type CodeAst struct {
FilePath string
Code string
AstInfos []*AstInfo
packageInfo *AstInfo
importInfos []*AstInfo
constInfos []*AstInfo
varInfos []*AstInfo
typeInfos []*AstInfo
funcInfos []*AstInfo
nonExistedConstCode []string
nonExistedVarCode []string
nonExistedTypeInfoMap map[string]*TypeInfo // key is type name
mergedStructMethodsMap map[string]struct{} // key is struct name
ignoreFuncNameMap map[string]struct{} // key is function name
changeCodeFlag bool
isCoverSameFunc bool
}
// NewCodeAst creates a new CodeAst object from file path
func NewCodeAst(filePath string, opts ...CodeAstOption) (*CodeAst, error) {
o := defaultClientOptions()
o.apply(opts...)
data, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
astInfos, err := ParseGoCode(filePath, data)
if err != nil {
return nil, err
}
codeAst := &CodeAst{
FilePath: filePath,
Code: string(data),
AstInfos: astInfos,
nonExistedTypeInfoMap: make(map[string]*TypeInfo),
mergedStructMethodsMap: make(map[string]struct{}),
isCoverSameFunc: o.isCoverSameFunc,
ignoreFuncNameMap: o.ignoreFuncNameMap,
}
codeAst.setSlices()
return codeAst, nil
}
// NewCodeAstFromData creates a new CodeAst object from data
func NewCodeAstFromData(data []byte, opts ...CodeAstOption) (*CodeAst, error) {
o := defaultClientOptions()
o.apply(opts...)
astInfos, err := ParseGoCode("", data)
if err != nil {
return nil, err
}
codeAst := &CodeAst{
Code: string(data),
AstInfos: astInfos,
nonExistedTypeInfoMap: make(map[string]*TypeInfo),
mergedStructMethodsMap: make(map[string]struct{}),
isCoverSameFunc: o.isCoverSameFunc,
ignoreFuncNameMap: o.ignoreFuncNameMap,
}
codeAst.setSlices()
return codeAst, nil
}
func (a *CodeAst) setSlices() {
for _, astInfo := range a.AstInfos {
switch astInfo.Type {
case PackageType:
a.packageInfo = astInfo
case ImportType:
a.importInfos = append(a.importInfos, astInfo)
case ConstType:
a.constInfos = append(a.constInfos, astInfo)
case VarType:
a.varInfos = append(a.varInfos, astInfo)
case TypeType:
a.typeInfos = append(a.typeInfos, astInfo)
case FuncType:
a.funcInfos = append(a.funcInfos, astInfo)
}
}
}
func (a *CodeAst) mergeImportCode(genAst *CodeAst) error {
if len(genAst.importInfos) == 0 {
return nil
}
// 1. append import code to package
if len(a.importInfos) == 0 {
srcStr := a.packageInfo.Body
dstStr := ""
for _, info := range genAst.AstInfos {
if info.IsImportType() {
dstStr += info.Body + "\n"
}
}
if strings.Count(a.Code, srcStr) > 1 {
return errDuplication("mergeImportCode", srcStr)
}
a.Code = strings.Replace(a.Code, srcStr, srcStr+"\n\n"+dstStr, 1)
a.changeCodeFlag = true
return nil
}
// 2. append import code to import
srcImportInfos, err := a.parseImportCode()
if err != nil {
return err
}
genImportInfos, err := genAst.parseImportCode()
if err != nil {
return err
}
srcLen := len(srcImportInfos)
srcImportInfoMap := make(map[string]struct{}, srcLen)
for _, srcIi := range srcImportInfos {
srcImportInfoMap[srcIi.Path] = struct{}{}
}
var nonExistedImportInfos []*ImportInfo
//var nonExistedImportPaths []string
for _, genIfi := range genImportInfos {
if _, ok := srcImportInfoMap[genIfi.Path]; !ok {
nonExistedImportInfos = append(nonExistedImportInfos, genIfi)
//nonExistedImportPaths = append(nonExistedImportPaths, genIfi.Path)
}
}
if len(nonExistedImportInfos) > 0 {
var srcStr = a.packageInfo.Body
var dstStr = "import (\n"
for _, info := range srcImportInfos {
if info.Comment != "" {
dstStr += info.Comment + "\n"
}
dstStr += "\t" + trimBody(info.Body, ImportType) + "\n"
}
for _, info := range nonExistedImportInfos {
if info.Comment != "" {
dstStr += info.Comment + "\n"
}
dstStr += "\t" + trimBody(info.Body, ImportType) + "\n"
}
dstStr += ")"
a.Code = strings.Replace(a.Code, srcStr, srcStr+"\n\n"+dstStr, 1)
a.changeCodeFlag = true
for _, info := range a.importInfos {
a.Code = strings.Replace(a.Code, info.Body, "", 1)
}
}
return nil
}
func (a *CodeAst) compareConstCode(genAst *CodeAst) error {
if len(genAst.constInfos) == 0 {
return nil
}
if len(a.constInfos) == 0 {
dstStr := ""
for _, info := range genAst.AstInfos {
if info.IsConstType() {
if info.Comment != "" {
dstStr += info.Comment + "\n"
}
dstStr += info.Body + "\n"
}
}
a.nonExistedConstCode = append(a.nonExistedConstCode, dstStr)
return nil
}
srcConstInfos, err := a.parseConstCode()
if err != nil {
return err
}
genConstInfos, err := genAst.parseConstCode()
if err != nil {
return err
}
srcLen := len(srcConstInfos)
srcConstInfoMap := make(map[string]struct{}, srcLen)
for _, srcCi := range srcConstInfos {
srcConstInfoMap[srcCi.Name] = struct{}{}
}
var nonExistedConstInfos []*ConstInfo
//var nonExistedConstNames []string
for _, genCi := range genConstInfos {
if _, ok := srcConstInfoMap[genCi.Name]; !ok {
nonExistedConstInfos = append(nonExistedConstInfos, genCi)
//nonExistedConstNames = append(nonExistedConstNames, genCi.Name)
}
}
if len(nonExistedConstInfos) > 0 {
var dstStr = "const (\n"
for _, info := range nonExistedConstInfos {
if info.Comment != "" {
if !strings.HasPrefix(info.Comment, "\t") {
info.Comment = "\t" + info.Comment
}
dstStr += info.Comment + "\n"
}
dstStr += "\t" + trimBody(info.Body, ConstType) + "\n"
}
dstStr += ")\n"
a.nonExistedConstCode = append(a.nonExistedConstCode, dstStr)
}
return nil
}
func (a *CodeAst) compareVarCode(genAst *CodeAst) error {
if len(genAst.varInfos) == 0 {
return nil
}
if len(a.varInfos) == 0 {
dstStr := ""
for _, info := range genAst.AstInfos {
if info.IsVarType() {
if info.Comment != "" {
dstStr += info.Comment + "\n"
}
dstStr += info.Body + "\n"
}
}
a.nonExistedVarCode = append(a.nonExistedVarCode, dstStr)
return nil
}
srcVarInfos, err := a.parseVarCode()
if err != nil {
return err
}
genVarInfos, err := genAst.parseVarCode()
if err != nil {
return err
}
srcLen := len(srcVarInfos)
srcVarInfoMap := make(map[string]struct{}, srcLen)
for _, srcVi := range srcVarInfos {
srcVarInfoMap[srcVi.Name] = struct{}{}
}
var nonExistedVarInfos []*VarInfo
//var nonExistedVarNames []string
for _, genVi := range genVarInfos {
if _, ok := srcVarInfoMap[genVi.Name]; !ok {
nonExistedVarInfos = append(nonExistedVarInfos, genVi)
//nonExistedVarNames = append(nonExistedVarNames, genVi.Name)
} else {
if genVi.Name == "_" && !strings.Contains(a.Code, strings.TrimSpace(genVi.Body)) {
nonExistedVarInfos = append(nonExistedVarInfos, genVi)
//nonExistedVarNames = append(nonExistedVarNames, genVi.Name)
}
}
}
if len(nonExistedVarInfos) > 0 {
var dstStr = "var (\n"
for _, info := range nonExistedVarInfos {
if info.Comment != "" {
if !strings.HasPrefix(info.Comment, "\t") {
info.Comment = "\t" + info.Comment
}
dstStr += info.Comment + "\n"
}
dstStr += "\t" + trimBody(info.Body, VarType) + "\n"
}
dstStr += ")\n"
a.nonExistedVarCode = append(a.nonExistedVarCode, dstStr)
}
return nil
}
func (a *CodeAst) mergeExistedTypeCode(genAst *CodeAst) error {
srcTypeInfosMap, err := a.parseTypeCode()
if err != nil {
return err
}
genTypeInfosMap, err := genAst.parseTypeCode()
if err != nil {
return err
}
var srcTypeNameMap = make(map[string]struct{})
for _, srcTypeInfos := range srcTypeInfosMap {
for _, info := range srcTypeInfos {
srcTypeNameMap[info.Name] = struct{}{}
}
}
var nonExistedTypeInfoMap = make(map[string]*TypeInfo)
for typeName, genTypeInfos := range genTypeInfosMap {
// get non-existed type infos
for _, info := range genTypeInfos {
if _, ok := srcTypeNameMap[info.Name]; !ok {
nonExistedTypeInfoMap[info.Name] = info
}
}
// merge existed interface method code and struct fields code
srcTypeInfos, ok := srcTypeInfosMap[typeName]
if !ok {
continue
}
srcTypeInfoMap := make(map[string]*TypeInfo, len(srcTypeInfos))
for _, srcTi := range srcTypeInfos {
srcTypeInfoMap[srcTi.Name] = srcTi
}
for _, genTypeInfo := range genTypeInfos {
switch genTypeInfo.Type {
case InterfaceType:
if srcTypeInfo, ok := srcTypeInfoMap[genTypeInfo.Name]; ok {
err = a.mergeInterfaceMethodCode(srcTypeInfo, genTypeInfo)
if err != nil {
return err
}
}
case StructType:
if srcTypeInfo, ok := srcTypeInfoMap[genTypeInfo.Name]; ok {
err = a.mergeStructFieldsCode(srcTypeInfo, genTypeInfo)
if err != nil {
return err
}
}
}
}
}
a.nonExistedTypeInfoMap = nonExistedTypeInfoMap
return nil
}
func (a *CodeAst) mergeInterfaceMethodCode(srcTypeInfo *TypeInfo, genTypeInfo *TypeInfo) error {
srcInterfaceInfos, err := ParseInterface(srcTypeInfo.Body)
if err != nil {
return err
}
genInterfaceInfos, err := ParseInterface(genTypeInfo.Body)
if err != nil {
return err
}
if len(srcInterfaceInfos) == 1 && len(genInterfaceInfos) == 1 {
srcLastMethodStr := ""
mLen := len(srcInterfaceInfos[0].MethodInfos)
srcInterfaceMethodInfoMap := make(map[string]struct{}, mLen)
for i, srcIm := range srcInterfaceInfos[0].MethodInfos {
srcInterfaceMethodInfoMap[srcIm.Name] = struct{}{}
if i == mLen-1 {
srcLastMethodStr = srcIm.Body
}
}
var newMethods []string
for _, genMethodInfo := range genInterfaceInfos[0].MethodInfos {
if _, ok := srcInterfaceMethodInfoMap[genMethodInfo.Name]; !ok {
newMethodStr := ""
if genMethodInfo.Comment != "" {
newMethodStr += genMethodInfo.Comment + "\n"
}
newMethodStr += genMethodInfo.Body
newMethods = append(newMethods, newMethodStr)
}
}
if len(newMethods) > 0 {
srcStr := srcTypeInfo.Body
if strings.Count(a.Code, srcStr) > 1 {
return errDuplication("mergeInterfaceMethodCode", srcStr)
}
dstStr := ""
if len(srcInterfaceInfos[0].MethodInfos) == 0 {
dstStr = "type " + srcInterfaceInfos[0].Name + " interface {\n" + strings.Join(newMethods, "\n") + "\n}"
} else {
dstStr = strings.Replace(srcStr, srcLastMethodStr, srcLastMethodStr+"\n"+strings.Join(newMethods, "\n"), 1)
}
a.Code = strings.Replace(a.Code, srcStr, dstStr, 1)
a.changeCodeFlag = true
}
}
return nil
}
func (a *CodeAst) mergeStructFieldsCode(srcTypeInfo *TypeInfo, genTypeInfo *TypeInfo) error {
srcStructInfos, err := ParseStruct(srcTypeInfo.Body)
if err != nil {
return err
}
genStructInfos, err := ParseStruct(genTypeInfo.Body)
if err != nil {
return err
}
for name, genStructInfo := range genStructInfos {
srcStructInfo, ok := srcStructInfos[name]
if !ok {
continue
}
fLen := len(srcStructInfo.Fields)
srcLastFieldStr := ""
srcFieldMap := make(map[string]struct{}, fLen)
for i, field := range srcStructInfo.Fields {
srcFieldMap[field.Name] = struct{}{}
if i == fLen-1 {
srcLastFieldStr = field.Body
}
}
var newFields []string
for _, field := range genStructInfo.Fields {
newFieldStr := ""
if _, ok := srcFieldMap[field.Name]; !ok {
if field.Comment != "" {
newFieldStr += field.Comment + "\n"
}
newFieldStr += field.Body
newFields = append(newFields, newFieldStr)
}
}
if len(newFields) > 0 {
srcStr := srcTypeInfo.Body
if strings.Count(a.Code, srcStr) > 1 {
return errDuplication("mergeStructFieldsCode", srcStr)
}
dstStr := ""
if len(srcStructInfo.Fields) == 0 {
dstStr = "type " + srcStructInfo.Name + " struct {\n" + strings.Join(newFields, "\n") + "\n}"
} else {
dstStr = strings.Replace(srcStr, srcLastFieldStr, srcLastFieldStr+"\n"+strings.Join(newFields, "\n"), 1)
}
a.Code = strings.Replace(a.Code, srcStr, dstStr, 1)
a.changeCodeFlag = true
}
}
return nil
}
func (a *CodeAst) mergeStructMethodsCode(genAst *CodeAst) error {
srcImportInfoMap := ParseStructMethods(a.AstInfos)
genImportInfoMap := ParseStructMethods(genAst.AstInfos)
for structName, genMethods := range genImportInfoMap {
var nonExistedImports []string
var lastMethodFuncCode string
if srcMethods, ok := srcImportInfoMap[structName]; ok {
var srcMethodMap = make(map[string]struct{}, len(srcMethods))
for i, srcMethod := range srcMethods {
srcMethodMap[srcMethod.Name] = struct{}{}
if i == len(srcMethods)-1 {
lastMethodFuncCode = srcMethod.Body
}
}
for _, genMethod := range genMethods {
if _, isExisted := srcMethodMap[genMethod.Name]; !isExisted {
nonExistedImports = append(nonExistedImports, genMethod.Comment+"\n"+genMethod.Body)
}
}
}
if len(nonExistedImports) > 0 {
srcStr := lastMethodFuncCode
if strings.Count(a.Code, srcStr) > 1 {
return errDuplication("mergeStructMethodsCode", srcStr)
}
dstStr := lastMethodFuncCode + "\n\n" + strings.Join(nonExistedImports, "\n\n")
a.Code = strings.Replace(a.Code, srcStr, dstStr, 1)
a.changeCodeFlag = true
a.mergedStructMethodsMap[structName] = struct{}{}
}
}
return nil
}
func (a *CodeAst) coverFuncCode(genAst *CodeAst) {
var srcFuncNameMap = make(map[string]*AstInfo)
for _, srcFuncInfo := range a.funcInfos {
srcFuncNameMap[srcFuncInfo.GetName()] = srcFuncInfo
}
for _, genFuncInfo := range genAst.funcInfos {
genFuncName := genFuncInfo.GetName()
if genFuncName == "init" || genFuncName == "_" {
continue
}
var ignoreFuncName string
if len(genFuncInfo.Names) == 2 {
ignoreFuncName = genFuncInfo.Names[0]
} else {
ignoreFuncName = genFuncName
}
if _, ok := a.ignoreFuncNameMap[ignoreFuncName]; ok {
continue
}
if srcFuncInfo, ok := srcFuncNameMap[genFuncName]; ok {
srcStr := srcFuncInfo.Body
dstStr := genFuncInfo.Body
comment := ""
if srcFuncInfo.Comment == "" && genFuncInfo.Comment != "" {
comment = genFuncInfo.Comment
}
if comment != "" {
dstStr = comment + "\n" + dstStr
}
a.Code = strings.Replace(a.Code, srcStr, dstStr, 1)
a.changeCodeFlag = true
}
}
}
// appends non-existed code to the end of the source code.
func (a *CodeAst) appendNonExistedCode(genAsts []*AstInfo) error { // nolint
srcNameAstMap := make(map[string]struct{}, len(a.AstInfos))
for _, info := range a.AstInfos {
srcNameAstMap[info.GetName()] = struct{}{}
}
if len(a.nonExistedConstCode) > 0 {
a.Code += "\n" + strings.Join(a.nonExistedConstCode, "\n")
a.changeCodeFlag = true
}
if len(a.nonExistedVarCode) > 0 {
a.Code += "\n" + strings.Join(a.nonExistedVarCode, "\n")
a.changeCodeFlag = true
}
var appendCodes []string
for _, genAst := range genAsts {
if genAst.IsPackageType() || genAst.IsImportType() || genAst.IsConstType() || genAst.IsVarType() {
continue
}
if genAst.IsFuncType() && len(genAst.Names) == 2 {
if _, ok := a.mergedStructMethodsMap[genAst.Names[1]]; ok {
continue
}
}
isNeedAppend := false
name := genAst.GetName()
if _, ok := srcNameAstMap[name]; !ok {
if genAst.IsTypeType() && len(genAst.Names) > 1 {
var nonExistedTypes []string
for _, name := range genAst.Names {
var dstStr string
if info, ok := a.nonExistedTypeInfoMap[name]; ok {
if info.Comment != "" {
dstStr += info.Comment + "\n"
}
dstStr += info.Body
nonExistedTypes = append(nonExistedTypes, dstStr)
}
}
if len(nonExistedTypes) > 0 {
appendCodes = append(appendCodes, "type (\n"+strings.Join(nonExistedTypes, "\n")+"\n)")
continue
}
}
isNeedAppend = true
} else {
if name == "_" && !strings.Contains(a.Code, genAst.Body) {
isNeedAppend = true
}
if genAst.IsFuncType() && name == "init" && !strings.Contains(a.Code, genAst.Body) {
isNeedAppend = true
}
}
if isNeedAppend {
comment := ""
if genAst.Comment != "" {
comment = genAst.Comment
}
appendCodes = append(appendCodes, comment+"\n"+genAst.Body)
}
}
if len(appendCodes) > 0 {
a.Code += strings.Join(appendCodes, "\n\n") + "\n"
a.changeCodeFlag = true
}
return nil
}
func (a *CodeAst) parseImportCode() ([]*ImportInfo, error) {
body := ""
for _, info := range a.importInfos {
if info.Comment != "" {
body += info.Comment + "\n"
}
body += info.Body + "\n"
}
return ParseImportGroup(body)
}
func (a *CodeAst) parseConstCode() ([]*ConstInfo, error) {
body := ""
for _, info := range a.constInfos {
if info.Comment != "" {
body += info.Comment + "\n"
}
body += info.Body + "\n"
}
return ParseConstGroup(body)
}
func (a *CodeAst) parseVarCode() ([]*VarInfo, error) {
body := ""
for _, info := range a.varInfos {
if info.Comment != "" {
body += info.Comment + "\n"
}
body += info.Body + "\n"
}
return ParseVarGroup(body)
}
func (a *CodeAst) parseTypeCode() (map[string][]*TypeInfo, error) {
body := ""
for _, info := range a.typeInfos {
if info.Comment != "" {
body += info.Comment + "\n"
}
body += info.Body + "\n"
}
typeInfos, err := ParseTypeGroup(body)
if err != nil {
return nil, err
}
typeMap := make(map[string][]*TypeInfo, len(typeInfos))
for _, info := range typeInfos {
if info.Name == "" {
continue
}
if _, ok := typeMap[info.Name]; !ok {
typeMap[info.Name] = []*TypeInfo{}
}
typeMap[info.Name] = append(typeMap[info.Name], info)
}
return typeMap, nil
}
func errDuplication(marker string, srcStr string) error {
return fmt.Errorf("%s: multiple duplicate string `%s` exists, please modify the source code to ensure uniqueness", marker, srcStr)
}
func trimBody(body string, codeType string) string {
body = strings.TrimSpace(body)
return strings.TrimPrefix(body, codeType+" ")
}
// MergeGoFile merges two Go code files into one.
func MergeGoFile(srcFile string, genFile string, opts ...CodeAstOption) (*CodeAst, error) {
srcAst, err := NewCodeAst(srcFile, opts...)
if err != nil {
return nil, err
}
genAst, err := NewCodeAst(genFile)
if err != nil {
return nil, err
}
return mergeCode(srcAst, genAst)
}
// MergeGoCode merges two Go code strings into one.
func MergeGoCode(srcCode []byte, genCode []byte, opts ...CodeAstOption) (*CodeAst, error) {
srcAst, err := NewCodeAstFromData(srcCode, opts...)
if err != nil {
return nil, err
}
genAst, err := NewCodeAstFromData(genCode, opts...)
if err != nil {
return nil, err
}
return mergeCode(srcAst, genAst)
}
func mergeCode(srcAst *CodeAst, genAst *CodeAst) (*CodeAst, error) {
// merge import code
err := srcAst.mergeImportCode(genAst)
if err != nil {
return nil, err
}
// compare const code
err = srcAst.compareConstCode(genAst)
if err != nil {
return nil, err
}
// compare var code
err = srcAst.compareVarCode(genAst)
if err != nil {
return nil, err
}
// merge interface method and struct fields code
err = srcAst.mergeExistedTypeCode(genAst)
if err != nil {
return nil, err
}
// merge struct method function code
err = srcAst.mergeStructMethodsCode(genAst)
if err != nil {
return nil, err
}
if srcAst.isCoverSameFunc {
// cover same function code
srcAst.coverFuncCode(genAst)
}
// append non-existed code
err = srcAst.appendNonExistedCode(genAst.AstInfos)
if err != nil {
return nil, err
}
if srcAst.changeCodeFlag {
data, err := format.Source([]byte(srcAst.Code))
if err == nil {
srcAst.Code = string(data)
}
}
return srcAst, nil
}

110
pkg/goast/merge_test.go Normal file
View File

@@ -0,0 +1,110 @@
package goast
import (
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
const (
srcFile = "data/src.go.code"
genFile = "data/gen.go.code"
)
func TestMergeGoFile(t *testing.T) {
t.Run("without cover same func", func(t *testing.T) {
codeAst, err := MergeGoFile(srcFile, genFile)
if err != nil {
t.Error(err)
return
}
assert.Equal(t, codeAst.changeCodeFlag, true)
fmt.Println(codeAst.Code)
})
t.Run("cover same func", func(t *testing.T) {
codeAst, err := MergeGoFile(srcFile, genFile, WithCoverSameFunc())
if err != nil {
t.Error(err)
return
}
assert.Equal(t, codeAst.changeCodeFlag, true)
fmt.Println(codeAst.Code)
})
t.Run("test same file", func(t *testing.T) {
codeAst, err := MergeGoFile(srcFile, srcFile)
if err != nil {
t.Error(err)
return
}
assert.Equal(t, codeAst.changeCodeFlag, false)
})
}
func TestMergeGoCode(t *testing.T) {
t.Run("without cover same func", func(t *testing.T) {
srcData, genData, _ := getGoCode()
codeAst, err := MergeGoCode(srcData, genData)
if err != nil {
t.Error(err)
return
}
assert.Equal(t, codeAst.changeCodeFlag, true)
fmt.Println(codeAst.Code)
})
t.Run("cover same func", func(t *testing.T) {
srcData, genData, _ := getGoCode()
codeAst, err := MergeGoCode(srcData, genData,
WithCoverSameFunc(),
WithIgnoreMergeFunc("GetByID", "Hi"))
if err != nil {
t.Error(err)
return
}
assert.Equal(t, codeAst.changeCodeFlag, true)
fmt.Println(codeAst.Code)
})
t.Run("test same file", func(t *testing.T) {
srcData, _, _ := getGoCode()
codeAst, err := MergeGoCode(srcData, srcData)
if err != nil {
t.Error(err)
return
}
assert.Equal(t, codeAst.changeCodeFlag, false)
})
}
func TestNewCodeAstFromData(t *testing.T) {
srcData, _, err := getGoCode()
if err != nil {
t.Error(err)
return
}
codeAst, err := NewCodeAstFromData(srcData)
if err != nil {
t.Error(err)
return
}
codeAst.FilePath = srcFile
assert.Equal(t, true, len(codeAst.AstInfos) > 0)
}
func getGoCode() ([]byte, []byte, error) {
srcData, err := os.ReadFile(srcFile)
if err != nil {
return nil, nil, err
}
genData, err := os.ReadFile(genFile)
if err != nil {
return nil, nil, err
}
return srcData, genData, nil
}