mirror of
https://github.com/zhufuyi/sponge.git
synced 2025-09-26 20:51:14 +08:00
feat: add Go AST libary
This commit is contained in:
69
pkg/goast/README.md
Normal file
69
pkg/goast/README.md
Normal file
@@ -0,0 +1,69 @@
|
||||
## goast
|
||||
|
||||
`goast` is a library for parsing Go code and extracting information, it supports merging two Go files into one.
|
||||
|
||||
## Example of use
|
||||
|
||||
### Parse Go code and extract information
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-dev-frame/sponge/pkg/goast"
|
||||
)
|
||||
|
||||
func main() {
|
||||
src := []byte(`package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("Hello, world!")
|
||||
}
|
||||
`)
|
||||
|
||||
// Case 1: Parse Go code and extract information
|
||||
{
|
||||
astInfos, err := goast.ParseGoCode("main.go", src)
|
||||
}
|
||||
|
||||
// Case 2: Parse file and extract information
|
||||
{
|
||||
astInfos, err := goast.ParseFile("main.go")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Merge two Go files into one
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-dev-frame/sponge/pkg/goast"
|
||||
)
|
||||
|
||||
func main() {
|
||||
const (
|
||||
srcFile = "data/src.go.code"
|
||||
genFile = "data/gen.go.code"
|
||||
)
|
||||
|
||||
// Case 1: without covering the same function
|
||||
{
|
||||
codeAst, err := goast.MergeGoFile(srcFile, genFile)
|
||||
fmt.Println(codeAst.Code)
|
||||
}
|
||||
|
||||
// Case 2: with covering the same function
|
||||
{
|
||||
codeAst, err := goast.MergeGoFile(srcFile, genFile, goast.WithCoverSameFunc())
|
||||
fmt.Println(codeAst.Code)
|
||||
}
|
||||
}
|
||||
```
|
814
pkg/goast/ast.go
Normal file
814
pkg/goast/ast.go
Normal file
@@ -0,0 +1,814 @@
|
||||
// Package goast is a library for parsing Go code and extracting information from it.
|
||||
package goast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// ast types
|
||||
|
||||
PackageType = "package"
|
||||
ImportType = "import"
|
||||
ConstType = "const"
|
||||
VarType = "var"
|
||||
FuncType = "func"
|
||||
TypeType = "type"
|
||||
|
||||
// for TypeType
|
||||
|
||||
StructType = "struct"
|
||||
InterfaceType = "interface"
|
||||
ArrayType = "array"
|
||||
MapType = "map"
|
||||
ChanType = "chan"
|
||||
)
|
||||
|
||||
// AstInfo Go code block information
|
||||
type AstInfo struct {
|
||||
// Type is the type of the code block, such as "func", "type", "const", "var", "import", "package".
|
||||
Type string
|
||||
|
||||
// Names is the name of the code block, such as "func Name", "type Names", "const Names", "var Names", "import Paths".
|
||||
// If Type is "func", a standalone function without a receiver has a single name.
|
||||
// If the function is a method belonging to a struct, it has two names: the first
|
||||
// represents the function name, and the second represents the struct name.
|
||||
Names []string
|
||||
|
||||
Comment string
|
||||
|
||||
Body string
|
||||
}
|
||||
|
||||
func (a *AstInfo) IsPackageType() bool {
|
||||
return a.Type == PackageType
|
||||
}
|
||||
func (a *AstInfo) IsImportType() bool {
|
||||
return a.Type == ImportType
|
||||
}
|
||||
func (a *AstInfo) IsConstType() bool {
|
||||
return a.Type == ConstType
|
||||
}
|
||||
func (a *AstInfo) IsVarType() bool {
|
||||
return a.Type == VarType
|
||||
}
|
||||
func (a *AstInfo) IsTypeType() bool {
|
||||
return a.Type == TypeType
|
||||
}
|
||||
func (a *AstInfo) IsFuncType() bool {
|
||||
return a.Type == FuncType
|
||||
}
|
||||
|
||||
func (a *AstInfo) GetName() string {
|
||||
return strings.Join(a.Names, ",")
|
||||
}
|
||||
|
||||
// ParseFile parses a go file and returns a list of AstInfo
|
||||
func ParseFile(goFilePath string) ([]*AstInfo, error) {
|
||||
filename := filepath.Base(goFilePath)
|
||||
data, err := os.ReadFile(goFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ParseGoCode(filename, data)
|
||||
}
|
||||
|
||||
// ParseGoCode parses a go code and returns a list of AstInfo
|
||||
func ParseGoCode(filename string, data []byte) ([]*AstInfo, error) {
|
||||
src := string(data)
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var astInfos []*AstInfo
|
||||
|
||||
pkgNames, pkgComment, pkgBody := getPackageCode(fset, file, src)
|
||||
astInfos = append(astInfos, &AstInfo{Type: PackageType, Names: pkgNames, Comment: pkgComment, Body: pkgBody})
|
||||
|
||||
// traverse AST code blocks
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
switch node := n.(type) {
|
||||
case *ast.FuncDecl:
|
||||
receiverName, comment, body := getFuncDeclCode(fset, node, src)
|
||||
names := []string{node.Name.Name}
|
||||
if receiverName != "" {
|
||||
names = append(names, receiverName)
|
||||
}
|
||||
astInfos = append(astInfos, &AstInfo{Type: FuncType, Names: names, Comment: comment, Body: body})
|
||||
|
||||
case *ast.GenDecl:
|
||||
names, comment, body := getGenDeclCode(fset, node, src)
|
||||
astInfos = append(astInfos, &AstInfo{Type: node.Tok.String(), Names: names, Comment: comment, Body: body})
|
||||
|
||||
//case *ast.BadDecl:
|
||||
// code := getBadDeclCode(fset, node, src)
|
||||
// println(code)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return astInfos, nil
|
||||
}
|
||||
|
||||
func getPackageCode(fset *token.FileSet, f *ast.File, src string) (names []string, comment string, body string) {
|
||||
if f.Doc != nil {
|
||||
var comments []string
|
||||
for _, cmt := range f.Doc.List {
|
||||
comments = append(comments, cmt.Text)
|
||||
}
|
||||
comment = strings.Join(comments, "\n")
|
||||
}
|
||||
|
||||
packagePos := fset.Position(f.Package)
|
||||
body = src[packagePos.Offset : packagePos.Offset+len("package "+f.Name.Name)]
|
||||
|
||||
return []string{f.Name.Name}, comment, body
|
||||
}
|
||||
|
||||
func getFuncDeclCode(fset *token.FileSet, fn *ast.FuncDecl, src string) (receiverName string, comment string, body string) {
|
||||
if fn == nil {
|
||||
return "", "", ""
|
||||
}
|
||||
|
||||
if fn.Recv != nil && len(fn.Recv.List) > 0 {
|
||||
recvType := fn.Recv.List[0].Type
|
||||
switch t := recvType.(type) {
|
||||
case *ast.StarExpr:
|
||||
if ident, ok := t.X.(*ast.Ident); ok {
|
||||
receiverName = ident.Name
|
||||
}
|
||||
case *ast.Ident:
|
||||
receiverName = t.Name
|
||||
}
|
||||
}
|
||||
|
||||
commentText := ""
|
||||
if fn.Doc != nil {
|
||||
var parts []string
|
||||
for _, c := range fn.Doc.List {
|
||||
parts = append(parts, strings.TrimSpace(c.Text))
|
||||
}
|
||||
commentText = strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
start := fn.Type.Func // the starting position of the func keyword
|
||||
end := fn.Body.Rbrace + 1 // end position of function body
|
||||
return receiverName, commentText, getCodeFromPos(fset, start, end, src)
|
||||
}
|
||||
|
||||
func getCodeFromPos(fset *token.FileSet, start, end token.Pos, src string) string {
|
||||
file := fset.File(start)
|
||||
if file == nil {
|
||||
return ""
|
||||
}
|
||||
startOffset := file.Offset(start)
|
||||
endOffset := file.Offset(end)
|
||||
if startOffset < 0 || endOffset > len(src) || startOffset >= endOffset {
|
||||
return ""
|
||||
}
|
||||
return src[startOffset:endOffset]
|
||||
}
|
||||
|
||||
func getGenDeclCode(fset *token.FileSet, gen *ast.GenDecl, src string) (names []string, comment string, body string) {
|
||||
if gen == nil {
|
||||
return nil, "", ""
|
||||
}
|
||||
|
||||
commentText := ""
|
||||
if gen.Doc != nil {
|
||||
var parts []string
|
||||
for _, c := range gen.Doc.List {
|
||||
parts = append(parts, strings.TrimSpace(c.Text))
|
||||
}
|
||||
commentText = strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
start := gen.TokPos // keyword starting position
|
||||
var end token.Pos
|
||||
if gen.Rparen.IsValid() {
|
||||
end = gen.Rparen + 1 // end position of parentheses
|
||||
} else if len(gen.Specs) > 0 {
|
||||
lastSpec := gen.Specs[len(gen.Specs)-1]
|
||||
end = lastSpec.End() // end position of the last Spec
|
||||
} else {
|
||||
end = start + token.Pos(len(gen.Tok.String())) // in the case of keywords only
|
||||
}
|
||||
|
||||
return getGenName(gen), commentText, getCodeFromPos(fset, start, end, src)
|
||||
}
|
||||
|
||||
func getGenName(gen *ast.GenDecl) []string {
|
||||
var names []string
|
||||
switch gen.Tok {
|
||||
case token.IMPORT:
|
||||
for _, spec := range gen.Specs {
|
||||
imp := spec.(*ast.ImportSpec)
|
||||
names = append(names, imp.Path.Value)
|
||||
}
|
||||
|
||||
case token.CONST:
|
||||
for _, spec := range gen.Specs {
|
||||
val := spec.(*ast.ValueSpec)
|
||||
names = append(names, val.Names[0].Name)
|
||||
}
|
||||
|
||||
case token.TYPE:
|
||||
for _, spec := range gen.Specs {
|
||||
typ := spec.(*ast.TypeSpec)
|
||||
names = append(names, typ.Name.Name)
|
||||
}
|
||||
|
||||
case token.VAR:
|
||||
for _, spec := range gen.Specs {
|
||||
val := spec.(*ast.ValueSpec)
|
||||
names = append(names, val.Names[0].Name)
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------------
|
||||
|
||||
func adaptPackage(src string) string {
|
||||
if len(src) > 50 {
|
||||
if strings.Contains(src[:50], "package ") {
|
||||
return src
|
||||
}
|
||||
}
|
||||
if strings.Contains(src, "\npackage ") {
|
||||
return src
|
||||
}
|
||||
return "package parse\n\n" + src
|
||||
}
|
||||
|
||||
// nolint
|
||||
func parseBody(body string) (*token.FileSet, *ast.File, string, error) {
|
||||
src := adaptPackage(body)
|
||||
fset := token.NewFileSet()
|
||||
f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
|
||||
if err != nil {
|
||||
f, err = parser.ParseFile(fset, "", "package parse\n\n"+src, parser.ParseComments)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
}
|
||||
return fset, f, src, nil
|
||||
}
|
||||
|
||||
type ImportInfo struct {
|
||||
Path string
|
||||
Alias string
|
||||
Comment string
|
||||
Body string
|
||||
}
|
||||
|
||||
// ParseImportGroup parse import group from source code
|
||||
func ParseImportGroup(body string) ([]*ImportInfo, error) {
|
||||
fset, f, src, err := parseBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var srcLines = strings.Split(src, "\n")
|
||||
var imports []*ImportInfo
|
||||
|
||||
for _, decl := range f.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.IMPORT {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, spec := range genDecl.Specs {
|
||||
importSpec := spec.(*ast.ImportSpec)
|
||||
|
||||
// get path
|
||||
path := strings.Trim(importSpec.Path.Value, `"`)
|
||||
|
||||
// get alias
|
||||
var alias string
|
||||
if importSpec.Name != nil {
|
||||
alias = importSpec.Name.Name
|
||||
}
|
||||
|
||||
// get comment doc
|
||||
var comment string
|
||||
if importSpec.Doc != nil {
|
||||
comment = getSrcContent(srcLines, fset.Position(importSpec.Doc.List[0].Pos()).Line,
|
||||
fset.Position(importSpec.Doc.List[len(importSpec.Doc.List)-1].End()).Line)
|
||||
}
|
||||
|
||||
// get source code of import path
|
||||
code := getSrcContent(srcLines, fset.Position(importSpec.Pos()).Line, fset.Position(importSpec.End()).Line)
|
||||
|
||||
imports = append(imports, &ImportInfo{
|
||||
Path: path,
|
||||
Alias: alias,
|
||||
Comment: comment,
|
||||
Body: code,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return imports, nil
|
||||
}
|
||||
|
||||
type ConstInfo struct {
|
||||
Name string
|
||||
Value string
|
||||
Comment string
|
||||
Body string
|
||||
}
|
||||
|
||||
// ParseConstGroup parse const group from source code
|
||||
func ParseConstGroup(body string) ([]*ConstInfo, error) {
|
||||
fset, f, src, err := parseBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var srcLines = strings.Split(src, "\n")
|
||||
var consts []*ConstInfo
|
||||
|
||||
for _, decl := range f.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.CONST {
|
||||
continue
|
||||
}
|
||||
|
||||
singleComment := ""
|
||||
if genDecl.Doc != nil {
|
||||
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
|
||||
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
|
||||
}
|
||||
|
||||
for _, spec := range genDecl.Specs {
|
||||
valueSpec := spec.(*ast.ValueSpec)
|
||||
for i, name := range valueSpec.Names {
|
||||
constName := name.Name
|
||||
|
||||
// get line content
|
||||
var comment string
|
||||
if valueSpec.Doc != nil {
|
||||
comment = getSrcContent(srcLines, fset.Position(valueSpec.Doc.List[0].Pos()).Line,
|
||||
fset.Position(valueSpec.Doc.List[len(valueSpec.Doc.List)-1].End()).Line)
|
||||
}
|
||||
if len(genDecl.Specs) == 1 && singleComment != "" && comment == "" {
|
||||
comment = singleComment
|
||||
}
|
||||
|
||||
// get code content
|
||||
code := getSrcContent(srcLines, fset.Position(valueSpec.Pos()).Line, fset.Position(valueSpec.End()).Line)
|
||||
|
||||
// get value (if exists)
|
||||
var constValue string
|
||||
if i < len(valueSpec.Values) {
|
||||
if basicLit, ok := valueSpec.Values[i].(*ast.BasicLit); ok {
|
||||
constValue = basicLit.Value
|
||||
}
|
||||
}
|
||||
|
||||
consts = append(consts, &ConstInfo{
|
||||
Name: constName,
|
||||
Value: constValue,
|
||||
Comment: comment,
|
||||
Body: code,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return consts, nil
|
||||
}
|
||||
|
||||
type VarInfo struct {
|
||||
Name string
|
||||
Value string
|
||||
Comment string
|
||||
Body string
|
||||
}
|
||||
|
||||
// ParseVarGroup parse var group from source code
|
||||
func ParseVarGroup(body string) ([]*VarInfo, error) {
|
||||
fset, f, src, err := parseBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var srcLines = strings.Split(src, "\n")
|
||||
var vars []*VarInfo
|
||||
|
||||
for _, decl := range f.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.VAR {
|
||||
continue
|
||||
}
|
||||
|
||||
singleComment := ""
|
||||
if genDecl.Doc != nil {
|
||||
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
|
||||
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
|
||||
}
|
||||
|
||||
for _, spec := range genDecl.Specs {
|
||||
valueSpec := spec.(*ast.ValueSpec)
|
||||
for i, name := range valueSpec.Names {
|
||||
varName := name.Name
|
||||
|
||||
// get comment
|
||||
var comment string
|
||||
if valueSpec.Doc != nil {
|
||||
comment = getSrcContent(srcLines, fset.Position(valueSpec.Doc.List[0].Pos()).Line,
|
||||
fset.Position(valueSpec.Doc.List[len(valueSpec.Doc.List)-1].End()).Line)
|
||||
}
|
||||
if len(genDecl.Specs) == 1 && singleComment != "" && comment == "" {
|
||||
comment = singleComment
|
||||
}
|
||||
|
||||
// get code content
|
||||
code := getSrcContent(srcLines, fset.Position(valueSpec.Pos()).Line, fset.Position(valueSpec.End()).Line)
|
||||
|
||||
// get var value (if exists)
|
||||
var varValue string
|
||||
if i < len(valueSpec.Values) {
|
||||
if basicLit, ok := valueSpec.Values[i].(*ast.BasicLit); ok {
|
||||
varValue = basicLit.Value
|
||||
}
|
||||
}
|
||||
|
||||
vars = append(vars, &VarInfo{
|
||||
Name: varName,
|
||||
Value: varValue,
|
||||
Comment: comment,
|
||||
Body: code,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return vars, nil
|
||||
}
|
||||
|
||||
type TypeInfo struct {
|
||||
Type string
|
||||
Name string
|
||||
Comment string
|
||||
Body string
|
||||
IsIdent bool
|
||||
}
|
||||
|
||||
// ParseTypeGroup parse type group from source code
|
||||
func ParseTypeGroup(body string) ([]*TypeInfo, error) {
|
||||
fset, f, src, err := parseBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var srcLines = strings.Split(src, "\n")
|
||||
var types []*TypeInfo
|
||||
|
||||
for _, decl := range f.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
|
||||
singleComment := ""
|
||||
if genDecl.Doc != nil {
|
||||
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
|
||||
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
|
||||
}
|
||||
|
||||
for _, spec := range genDecl.Specs {
|
||||
typeSpec := spec.(*ast.TypeSpec)
|
||||
typeName := typeSpec.Name.Name
|
||||
|
||||
// get comment
|
||||
var comment string
|
||||
if typeSpec.Doc != nil {
|
||||
comment = getSrcContent(srcLines, fset.Position(typeSpec.Doc.List[0].Pos()).Line,
|
||||
fset.Position(typeSpec.Doc.List[len(typeSpec.Doc.List)-1].End()).Line)
|
||||
}
|
||||
if len(genDecl.Specs) == 1 && singleComment != "" && comment == "" {
|
||||
comment = singleComment
|
||||
}
|
||||
|
||||
// get code content
|
||||
code := getSrcContent(srcLines, fset.Position(typeSpec.Pos()).Line, fset.Position(typeSpec.End()).Line)
|
||||
|
||||
// get type definition
|
||||
var isIdent bool
|
||||
var typeDef string
|
||||
switch t := typeSpec.Type.(type) {
|
||||
case *ast.StructType:
|
||||
typeDef = StructType
|
||||
case *ast.InterfaceType:
|
||||
typeDef = InterfaceType
|
||||
case *ast.FuncType:
|
||||
typeDef = FuncType
|
||||
case *ast.MapType:
|
||||
typeDef = MapType
|
||||
case *ast.ArrayType:
|
||||
typeDef = ArrayType
|
||||
case *ast.ChanType:
|
||||
typeDef = ChanType
|
||||
case *ast.Ident:
|
||||
typeDef = t.Name
|
||||
isIdent = true
|
||||
default:
|
||||
typeDef = fmt.Sprintf("%T", t)
|
||||
}
|
||||
|
||||
types = append(types, &TypeInfo{
|
||||
Type: typeDef,
|
||||
Name: typeName,
|
||||
Comment: comment,
|
||||
Body: code,
|
||||
IsIdent: isIdent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return types, nil
|
||||
}
|
||||
|
||||
type InterfaceInfo struct {
|
||||
Name string
|
||||
Comment string
|
||||
MethodInfos []*MethodInfo
|
||||
}
|
||||
|
||||
// ParseInterface parse interface group from source code
|
||||
func ParseInterface(body string) ([]*InterfaceInfo, error) {
|
||||
fset, f, src, err := parseBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var srcLines = strings.Split(src, "\n")
|
||||
var interfaceInfos []*InterfaceInfo
|
||||
|
||||
for _, decl := range f.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
|
||||
singleComment := ""
|
||||
if genDecl.Doc != nil {
|
||||
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
|
||||
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
|
||||
}
|
||||
|
||||
var methodInfos []*MethodInfo
|
||||
for _, spec := range genDecl.Specs {
|
||||
typeSpec := spec.(*ast.TypeSpec)
|
||||
interfaceType, ok := typeSpec.Type.(*ast.InterfaceType)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
interfaceName := typeSpec.Name.Name
|
||||
|
||||
// get interface comment
|
||||
var interfaceComment string
|
||||
if typeSpec.Doc != nil {
|
||||
interfaceComment = getSrcContent(srcLines, fset.Position(typeSpec.Doc.List[0].Pos()).Line,
|
||||
fset.Position(typeSpec.Doc.List[len(typeSpec.Doc.List)-1].End()).Line)
|
||||
}
|
||||
if len(genDecl.Specs) == 1 && singleComment != "" && interfaceComment == "" {
|
||||
interfaceComment = singleComment
|
||||
}
|
||||
|
||||
var isIdent bool
|
||||
for _, method := range interfaceType.Methods.List {
|
||||
// get method name
|
||||
var methodName string
|
||||
switch t := method.Type.(type) {
|
||||
case *ast.FuncType:
|
||||
if len(method.Names) > 0 {
|
||||
methodName = method.Names[0].Name
|
||||
}
|
||||
case *ast.Ident: // embedded interface
|
||||
methodName = t.Name
|
||||
isIdent = true
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
// get method comment
|
||||
var methodComment string
|
||||
if method.Doc != nil {
|
||||
methodComment = getSrcContent(srcLines, fset.Position(method.Doc.List[0].Pos()).Line,
|
||||
fset.Position(method.Doc.List[len(method.Doc.List)-1].End()).Line)
|
||||
}
|
||||
|
||||
// get method line content
|
||||
code := getSrcContent(srcLines, fset.Position(method.Pos()).Line, fset.Position(method.End()).Line)
|
||||
|
||||
methodInfos = append(methodInfos, &MethodInfo{
|
||||
Name: methodName,
|
||||
Comment: methodComment,
|
||||
Body: code,
|
||||
ReceiverName: interfaceName,
|
||||
IsIdent: isIdent,
|
||||
})
|
||||
}
|
||||
interfaceInfos = append(interfaceInfos, &InterfaceInfo{
|
||||
Name: interfaceName,
|
||||
Comment: interfaceComment,
|
||||
MethodInfos: methodInfos,
|
||||
})
|
||||
}
|
||||
}
|
||||
return interfaceInfos, nil
|
||||
}
|
||||
|
||||
// MethodInfo method function info
|
||||
type MethodInfo struct {
|
||||
Name string
|
||||
Comment string
|
||||
Body string
|
||||
ReceiverName string
|
||||
IsIdent bool
|
||||
}
|
||||
|
||||
// ParseStructMethods parse struct methods from ast infos
|
||||
func ParseStructMethods(astInfos []*AstInfo) map[string][]*MethodInfo {
|
||||
var m = make(map[string][]*MethodInfo) // map[structName][]*MethodInfo
|
||||
|
||||
for _, info := range astInfos {
|
||||
if !info.IsFuncType() {
|
||||
continue
|
||||
}
|
||||
if len(info.Names) == 2 {
|
||||
funcName := info.Names[0]
|
||||
structName := info.Names[1]
|
||||
methodAst := &MethodInfo{
|
||||
Name: funcName,
|
||||
Comment: info.Comment,
|
||||
Body: info.Body,
|
||||
ReceiverName: structName,
|
||||
}
|
||||
if methodInfos, ok := m[structName]; !ok {
|
||||
m[structName] = []*MethodInfo{methodAst}
|
||||
} else {
|
||||
m[structName] = append(methodInfos, methodAst)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
type StructInfo struct {
|
||||
Name string
|
||||
Comment string
|
||||
Fields []*StructFieldInfo
|
||||
}
|
||||
|
||||
type StructFieldInfo struct {
|
||||
Name string
|
||||
Type string
|
||||
Comment string
|
||||
Body string
|
||||
}
|
||||
|
||||
// ParseStruct parse struct info from source code
|
||||
func ParseStruct(body string) (map[string]*StructInfo, error) { //nolint
|
||||
fset, f, src, err := parseBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var srcLines = strings.Split(src, "\n")
|
||||
var structInfos = make(map[string]*StructInfo)
|
||||
|
||||
for _, decl := range f.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
|
||||
singleComment := ""
|
||||
if genDecl.Doc != nil {
|
||||
singleComment = getSrcContent(srcLines, fset.Position(genDecl.Doc.List[0].Pos()).Line,
|
||||
fset.Position(genDecl.Doc.List[len(genDecl.Doc.List)-1].End()).Line)
|
||||
}
|
||||
|
||||
for _, spec := range genDecl.Specs {
|
||||
typeSpec, ok := spec.(*ast.TypeSpec)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
structType, ok := typeSpec.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
structName := typeSpec.Name.Name
|
||||
|
||||
// get struct comment
|
||||
var structComment string
|
||||
if typeSpec.Doc != nil {
|
||||
structComment = getSrcContent(srcLines, fset.Position(typeSpec.Doc.List[0].Pos()).Line,
|
||||
fset.Position(typeSpec.Doc.List[len(typeSpec.Doc.List)-1].End()).Line)
|
||||
}
|
||||
if len(genDecl.Specs) == 1 && singleComment != "" && structComment == "" {
|
||||
structComment = singleComment
|
||||
}
|
||||
|
||||
var fields []*StructFieldInfo
|
||||
for _, field := range structType.Fields.List {
|
||||
var fieldNames []string
|
||||
if len(field.Names) > 0 {
|
||||
for _, name := range field.Names {
|
||||
fieldNames = append(fieldNames, name.Name)
|
||||
}
|
||||
} else {
|
||||
// 处理嵌入字段
|
||||
switch x := field.Type.(type) {
|
||||
case *ast.Ident:
|
||||
fieldNames = append(fieldNames, x.Name)
|
||||
case *ast.StarExpr:
|
||||
if ident, ok := x.X.(*ast.Ident); ok {
|
||||
fieldNames = append(fieldNames, "*"+ident.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get field name
|
||||
var fieldName string
|
||||
if len(fieldNames) > 0 {
|
||||
fieldName = fieldNames[0]
|
||||
}
|
||||
|
||||
// get comment
|
||||
var comment string
|
||||
if field.Doc != nil {
|
||||
comment = getSrcContent(srcLines, fset.Position(field.Doc.List[0].Pos()).Line,
|
||||
fset.Position(field.Doc.List[len(field.Doc.List)-1].End()).Line)
|
||||
}
|
||||
|
||||
// get source code of field
|
||||
code := getSrcContent(srcLines, fset.Position(field.Pos()).Line, fset.Position(field.End()).Line)
|
||||
|
||||
fields = append(fields, &StructFieldInfo{
|
||||
Name: fieldName,
|
||||
Type: getTypeString(field.Type),
|
||||
Comment: comment,
|
||||
Body: code,
|
||||
})
|
||||
}
|
||||
structInfos[structName] = &StructInfo{
|
||||
Name: structName,
|
||||
Comment: structComment,
|
||||
Fields: fields,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return structInfos, nil
|
||||
}
|
||||
|
||||
func getTypeString(expr ast.Expr) string {
|
||||
switch t := expr.(type) {
|
||||
case *ast.Ident:
|
||||
return t.Name
|
||||
case *ast.ArrayType:
|
||||
return "[]" + getTypeString(t.Elt)
|
||||
case *ast.StructType:
|
||||
return "struct"
|
||||
case *ast.SelectorExpr:
|
||||
return getTypeString(t.X) + "." + t.Sel.Name
|
||||
case *ast.StarExpr:
|
||||
return "*" + getTypeString(t.X)
|
||||
case *ast.MapType:
|
||||
return "map[" + getTypeString(t.Key) + "]" + getTypeString(t.Value)
|
||||
case *ast.InterfaceType:
|
||||
return "interface{}"
|
||||
case *ast.ChanType:
|
||||
return "chan " + getTypeString(t.Value)
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func getSrcContent(srcLines []string, start, end int) string {
|
||||
var srcContent string
|
||||
l := len(srcLines)
|
||||
if start < 1 || start > l || end < 1 || end > l {
|
||||
return ""
|
||||
}
|
||||
if start == end {
|
||||
srcContent = srcLines[start-1]
|
||||
} else {
|
||||
srcContent = strings.Join(srcLines[start-1:end], "\n")
|
||||
}
|
||||
return srcContent
|
||||
}
|
341
pkg/goast/ast_test.go
Normal file
341
pkg/goast/ast_test.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package goast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseGoFile(t *testing.T) {
|
||||
astInfos, err := ParseFile("ast.go")
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(astInfos), 10)
|
||||
}
|
||||
|
||||
func TestParseGoCode(t *testing.T) {
|
||||
var src = `
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
import "strings"
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
pi = 3.14
|
||||
language = "Go"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "v1.0.0"
|
||||
repo = "sponge"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
func (u *User) SayHello() {
|
||||
fmt.Println("Hello, my name is", u.Name)
|
||||
}
|
||||
|
||||
func main() {
|
||||
fmt.Println(pi)
|
||||
fmt.Println(language)
|
||||
fmt.Println(version)
|
||||
|
||||
user:=&User{Name:"Tom",Age:20}
|
||||
fmt.Println(user.Name)
|
||||
fmt.Println(user.Age)
|
||||
user.SayHello()
|
||||
}
|
||||
`
|
||||
|
||||
astInfos, err := ParseGoCode("", []byte(src))
|
||||
assert.NoError(t, err)
|
||||
for _, info := range astInfos {
|
||||
fmt.Printf(" %-20s: %s\n", "name", info.Names)
|
||||
fmt.Printf(" %-20s: %s\n", "comment", info.Comment)
|
||||
fmt.Printf(" %-20s: %s\n\n\n", "body", info.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImportGroup(t *testing.T) {
|
||||
body := `
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
//"github.com/spf13/viper"
|
||||
apiV1 "yourModuleName/api/v1"
|
||||
// api v2
|
||||
apiV2 "yourModuleName/api/v2"
|
||||
)
|
||||
`
|
||||
|
||||
importInfos, err := ParseImportGroup(body)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
for _, ii := range importInfos {
|
||||
fmt.Printf(" %-20s: %s\n", "name", ii.Path)
|
||||
fmt.Printf(" %-20s: %s\n", "alias", ii.Alias)
|
||||
fmt.Printf(" %-20s: %s\n", "comment", ii.Comment)
|
||||
fmt.Printf(" %-20s: %s\n\n\n", "body", ii.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConstGroup(t *testing.T) {
|
||||
body := `
|
||||
// pi constant
|
||||
const pi = 3.14
|
||||
|
||||
const (
|
||||
// Version number
|
||||
version = "v1.0.0"
|
||||
)
|
||||
|
||||
const (
|
||||
// Development language
|
||||
language = "Go"
|
||||
|
||||
// database type
|
||||
dbDriver = "mysql"
|
||||
)
|
||||
`
|
||||
|
||||
constInfos, err := ParseConstGroup(body)
|
||||
if err != nil {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
for _, ci := range constInfos {
|
||||
fmt.Printf(" %-20s: %s\n", "name", ci.Name)
|
||||
fmt.Printf(" %-20s: %s\n", "value", ci.Value)
|
||||
fmt.Printf(" %-20s: %s\n", "comment", ci.Comment)
|
||||
fmt.Printf(" %-20s: %s\n\n\n", "body", ci.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseVarGroup(t *testing.T) {
|
||||
body := `
|
||||
var (
|
||||
// Version number
|
||||
version = "v1.0.0"
|
||||
|
||||
// Author
|
||||
author = "name"
|
||||
|
||||
// Repository
|
||||
repo = "sponge"
|
||||
|
||||
// Function variable
|
||||
f1 = func() {
|
||||
fmt.Println("hello")
|
||||
}
|
||||
)
|
||||
`
|
||||
|
||||
varInfos, err := ParseVarGroup(body)
|
||||
if err != nil {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
for _, vi := range varInfos {
|
||||
fmt.Printf(" %-20s: %s\n", "name", vi.Name)
|
||||
fmt.Printf(" %-20s: %s\n", "value", vi.Value)
|
||||
fmt.Printf(" %-20s: %s\n", "comment", vi.Comment)
|
||||
fmt.Printf(" %-20s: %s\n\n\n", "body", vi.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTypeGroup(t *testing.T) {
|
||||
body := `
|
||||
type (
|
||||
// Struct type
|
||||
ts struct {
|
||||
name string
|
||||
}
|
||||
|
||||
// Function type
|
||||
tfn func(name string) bool
|
||||
|
||||
// Interface type
|
||||
iFace interface {}
|
||||
|
||||
// Channel type
|
||||
ch chan int
|
||||
|
||||
// Map type
|
||||
m map[string]bool
|
||||
|
||||
// Slice type
|
||||
slice []int
|
||||
)
|
||||
`
|
||||
|
||||
typeInfos, err := ParseTypeGroup(body)
|
||||
if err != nil {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, ti := range typeInfos {
|
||||
fmt.Printf(" %-20s: %s\n", "type", ti.Type)
|
||||
fmt.Printf(" %-20s: %s\n", "name", ti.Name)
|
||||
fmt.Printf(" %-20s: %s\n", "comment", ti.Comment)
|
||||
fmt.Printf(" %-20s: %s\n\n\n", "body", ti.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseInterface(t *testing.T) {
|
||||
body := `
|
||||
type GreeterDao interface {
|
||||
// get by id
|
||||
Create(ctx context.Context, table *model.Greeter) error
|
||||
// delete by id
|
||||
DeleteByID(ctx context.Context, id uint64) error
|
||||
// update by id
|
||||
UpdateByID(ctx context.Context, table *model.Greeter) error
|
||||
UserExampleDao
|
||||
}
|
||||
|
||||
type UserExampleDao interface {
|
||||
// get by id
|
||||
Create(ctx context.Context, table *model.UserExample) error
|
||||
// update by id
|
||||
UpdateByID(ctx context.Context, table *model.UserExample) error
|
||||
}
|
||||
`
|
||||
|
||||
interfaceInfos, err := ParseInterface(body)
|
||||
if err != nil {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
for _, info := range interfaceInfos {
|
||||
fmt.Printf("%-20s : %s\n", "name", info.Name)
|
||||
fmt.Printf("%-20s : %s\n", "comment", info.Comment)
|
||||
for _, mi := range info.MethodInfos {
|
||||
fmt.Printf(" %-20s: %s\n", "method name", mi.Name)
|
||||
fmt.Printf(" %-20s: %s\n", "comment", mi.Comment)
|
||||
fmt.Printf(" %-20s: %s\n", "body", mi.Body)
|
||||
fmt.Printf(" %-20v: %t\n\n\n", "embedded", mi.IsIdent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStructMethods(t *testing.T) {
|
||||
src := `
|
||||
package demo
|
||||
|
||||
type userHandler struct {
|
||||
server userV1.UserServer
|
||||
}
|
||||
|
||||
// Create a record
|
||||
func (h *userHandler) Create(ctx context.Context, req *userV1.CreateUserRequest) (*userV1.CreateUserReply, error) {
|
||||
return h.server.Create(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteByID delete a record by id
|
||||
func (h *userHandler) DeleteByID(ctx context.Context, req *userV1.DeleteUserByIDRequest) (*userV1.DeleteUserByIDReply, error) {
|
||||
return h.server.DeleteByID(ctx, req)
|
||||
}
|
||||
|
||||
// UpdateByID update a record by id
|
||||
func (h *userHandler) UpdateByID(ctx context.Context, req *userV1.UpdateUserByIDRequest) (*userV1.UpdateUserByIDReply, error) {
|
||||
return h.server.UpdateByID(ctx, req)
|
||||
}
|
||||
|
||||
type greeterHandler struct {
|
||||
server greeterV1.GreeterServer
|
||||
}
|
||||
|
||||
// Create a record
|
||||
func (h *greeterHandler) Create(ctx context.Context, req *greeterV1.CreateGreeterRequest) (*greeterV1.CreateGreeterReply, error) {
|
||||
return h.server.Create(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteByID delete a record by id
|
||||
func (h *greeterHandler) DeleteByID(ctx context.Context, req *greeterV1.DeleteGreeterByIDRequest) (*greeterV1.DeleteGreeterByIDReply, error) {
|
||||
return h.server.DeleteByID(ctx, req)
|
||||
}
|
||||
`
|
||||
astInfos, err := ParseGoCode("", []byte(src))
|
||||
if err != nil {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
methods := ParseStructMethods(astInfos)
|
||||
for structName, methodInfos := range methods {
|
||||
fmt.Printf("%-20s : %s\n", "name", structName)
|
||||
for _, mi := range methodInfos {
|
||||
fmt.Printf(" %-20s: %s\n", "method name", mi.Name)
|
||||
fmt.Printf(" %-20s: %s\n", "comment", mi.Comment)
|
||||
fmt.Printf(" %-20s: %s\n\n\n", "body", mi.Body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStruct(t *testing.T) {
|
||||
body := `
|
||||
package goast
|
||||
|
||||
// AstInfo is the information of a code block.
|
||||
type AstInfo struct {
|
||||
Kind string
|
||||
|
||||
// Names is the name of the code block, such as "func Name", "type Names", "const Names", "var Names", "import Paths".
|
||||
// If Type is "func", a standalone function without a receiver has a single name.
|
||||
// If the function is a method belonging to a struct, it has two names: the first
|
||||
// represents the function name, and the second represents the struct name.
|
||||
Names []string // todo add name
|
||||
|
||||
// User information
|
||||
User struct {
|
||||
Name string
|
||||
Age string
|
||||
}
|
||||
|
||||
// embedded struct
|
||||
*Address
|
||||
// embedded struct2
|
||||
Address
|
||||
|
||||
reader interface {}
|
||||
writer any
|
||||
sayMap map[string]string
|
||||
ch1 chan int
|
||||
ch2 chan *Address
|
||||
}
|
||||
|
||||
// Address is address
|
||||
type Address struct {
|
||||
State string
|
||||
// Addr is address
|
||||
Addr string
|
||||
}
|
||||
`
|
||||
structInfos, err := ParseStruct(body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
for name, structInfo := range structInfos {
|
||||
fmt.Printf("%-20s : %s\n", "struct name", name)
|
||||
fmt.Printf("%-20s : %s\n", "struct comment", structInfo.Comment)
|
||||
for _, field := range structInfo.Fields {
|
||||
fmt.Printf(" %-20s: %s\n", "name", field.Name)
|
||||
fmt.Printf(" %-20s: %s\n", "type", field.Type)
|
||||
fmt.Printf(" %-20s: %s\n", "comment", field.Comment)
|
||||
fmt.Printf(" %-20s: %s\n\n", "body", field.Body)
|
||||
}
|
||||
fmt.Printf("\n\n\n")
|
||||
}
|
||||
}
|
105
pkg/goast/data/gen.go.code
Normal file
105
pkg/goast/data/gen.go.code
Normal file
@@ -0,0 +1,105 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
// this is a comment
|
||||
"bytes"
|
||||
)
|
||||
|
||||
import "fmt"
|
||||
|
||||
// pi is a constant
|
||||
const pi = 3.14
|
||||
|
||||
// new const
|
||||
const (
|
||||
// MinRetries min retries
|
||||
MinRetries = 1
|
||||
// Timeout1 time out 1
|
||||
Timeout1 = 60
|
||||
// Timeout2 time out 2
|
||||
Timeout2 = 60
|
||||
Timeout = 5
|
||||
)
|
||||
|
||||
// a global variable a
|
||||
var a = 1
|
||||
var b = 2
|
||||
|
||||
// c global variable c
|
||||
var c = 3
|
||||
|
||||
type person struct {
|
||||
email string
|
||||
// age is the person's age
|
||||
age int
|
||||
}
|
||||
|
||||
// SayHello is a method
|
||||
// sayHello is a method 2
|
||||
func (p *person) SayHello() {
|
||||
fmt.Println("Hello, my age is", p.age)
|
||||
}
|
||||
|
||||
// GetEmail is a method
|
||||
// getEmail is a method 2
|
||||
func (p *person) GetEmail() {
|
||||
fmt.Println("Hi, my email is", p.email)
|
||||
}
|
||||
|
||||
type iSayer interface {
|
||||
// Say is a method
|
||||
Say()
|
||||
// SayHello is a method
|
||||
SayHello(name string)
|
||||
}
|
||||
|
||||
type iSayer2 interface {
|
||||
// Hi is a method
|
||||
Hi()
|
||||
}
|
||||
|
||||
type sayer interface {
|
||||
Say()
|
||||
Hi()
|
||||
}
|
||||
|
||||
type (
|
||||
fn1 func()
|
||||
chan1 chan int
|
||||
map1 map[string]int
|
||||
slice1 []int
|
||||
Speaker interface {
|
||||
Say()
|
||||
Hi()
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
_ = func() {}
|
||||
_ = chan1(nil)
|
||||
)
|
||||
|
||||
var _ []int
|
||||
var _ = fn1(func() {})
|
||||
|
||||
// Merge merges two slices
|
||||
func Merge(slice [][]byte) {
|
||||
fmt.Println(bytes.Join(slice, []byte("\n")))
|
||||
}
|
||||
|
||||
func GetByID(id int) {
|
||||
// do something
|
||||
}
|
||||
|
||||
// Hi is a method
|
||||
func (p *person) Hi() {
|
||||
fmt.Println("Hi, my age is", p.age)
|
||||
}
|
||||
|
||||
func init() {
|
||||
fmt.Println("init")
|
||||
}
|
||||
|
||||
func init() {
|
||||
fmt.Printf("%s\n", "init2")
|
||||
}
|
74
pkg/goast/data/src.go.code
Normal file
74
pkg/goast/data/src.go.code
Normal file
@@ -0,0 +1,74 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const e = 2.7
|
||||
|
||||
const (
|
||||
MinRetries = 2
|
||||
Timeout = 1
|
||||
)
|
||||
|
||||
func join() {
|
||||
fmt.Println(strings.Join([]string{"a", "b", "c"}, ", "))
|
||||
}
|
||||
|
||||
// x is a variable
|
||||
var x = 1
|
||||
|
||||
var (
|
||||
a = 1
|
||||
y = 2
|
||||
)
|
||||
|
||||
type person struct {
|
||||
name string
|
||||
}
|
||||
|
||||
// Say is a method
|
||||
func (p *person) Say() {
|
||||
fmt.Println("Hello, my name is", p.name)
|
||||
}
|
||||
|
||||
// Hi is a method
|
||||
func (p *person) Hi() {
|
||||
fmt.Println("Hi, my name is", p.name)
|
||||
}
|
||||
|
||||
type iSayer interface {
|
||||
// Say is a method
|
||||
Say()
|
||||
// Hi is a method
|
||||
Hi()
|
||||
}
|
||||
|
||||
type iSayer2 interface {
|
||||
// Say is a method
|
||||
Say()
|
||||
}
|
||||
|
||||
type sayer interface{}
|
||||
|
||||
type (
|
||||
fn1 func()
|
||||
chan1 chan int
|
||||
)
|
||||
|
||||
type _ func(name int) bool
|
||||
|
||||
var _ = fn1(func() {})
|
||||
|
||||
func Merge(slice [][]byte) {
|
||||
|
||||
}
|
||||
|
||||
func GetByID(id int) {
|
||||
|
||||
}
|
||||
|
||||
func init() {
|
||||
fmt.Println("init")
|
||||
}
|
192
pkg/goast/filter.go
Normal file
192
pkg/goast/filter.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package goast
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FuncInfo represents function information
|
||||
type FuncInfo struct {
|
||||
Name string
|
||||
Comment string
|
||||
}
|
||||
|
||||
// ExtractComment extracts function comments in Go code
|
||||
func (f FuncInfo) ExtractComment() string {
|
||||
if f.Comment == "" {
|
||||
return ""
|
||||
}
|
||||
comment := f.Comment
|
||||
|
||||
// regular matching `//` or `/* */` comments
|
||||
lineComment := regexp.MustCompile(`(?m)^//\s?`)
|
||||
blockComment := regexp.MustCompile(`(?m)/\*|\*/`)
|
||||
|
||||
// remove the `//` or `/* */` tags
|
||||
comment = lineComment.ReplaceAllString(comment, "")
|
||||
comment = blockComment.ReplaceAllString(comment, "")
|
||||
|
||||
// remove the space at the beginning of the line and split the line
|
||||
lines := strings.Split(comment, "\n")
|
||||
for i := range lines {
|
||||
lines[i] = strings.TrimSpace(lines[i])
|
||||
}
|
||||
|
||||
// output the comment string
|
||||
commentStr := strings.Join(lines, "\n")
|
||||
commentStr = strings.TrimSpace(commentStr)
|
||||
commentStr = strings.TrimPrefix(commentStr, f.Name)
|
||||
return strings.TrimSpace(commentStr)
|
||||
}
|
||||
|
||||
// containsPanicCall determine if there is a panic("implement me"), or customized flag, e.g. panic("ai to do")
|
||||
func containsPanicCall(fn *ast.FuncDecl, customFlag ...string) bool {
|
||||
found := false
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
ce, ok := n.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
ident, ok := ce.Fun.(*ast.Ident)
|
||||
if !ok || ident.Name != "panic" {
|
||||
return true
|
||||
}
|
||||
if len(ce.Args) > 0 {
|
||||
basicLit, ok := ce.Args[0].(*ast.BasicLit)
|
||||
if ok && basicLit.Kind == token.STRING {
|
||||
s, err := strconv.Unquote(basicLit.Value)
|
||||
if err == nil {
|
||||
if strings.HasPrefix(s, "implement me") {
|
||||
found = true
|
||||
return false // stop traversing if you find it.
|
||||
}
|
||||
for _, flag := range customFlag {
|
||||
if strings.Contains(s, flag) {
|
||||
found = true
|
||||
return false // stop traversing if you find it.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
// interval indicates an area to delete
|
||||
type interval struct {
|
||||
start token.Pos
|
||||
end token.Pos
|
||||
}
|
||||
|
||||
// FilterFuncCodeByFile filters out the code of functions that contain panic("implement me") or customized flag, e.g. panic("ai to do")
|
||||
func FilterFuncCodeByFile(goFilePath string, customFlag ...string) ([]byte, []FuncInfo, error) {
|
||||
filename := filepath.Base(goFilePath)
|
||||
data, err := os.ReadFile(goFilePath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return FilterFuncCode(filename, data, customFlag...)
|
||||
}
|
||||
|
||||
// FilterFuncCode filters out the code of functions that contain panic("implement me") or customized flag, e.g. panic("ai to do")
|
||||
func FilterFuncCode(filename string, data []byte, customFlag ...string) ([]byte, []FuncInfo, error) {
|
||||
fset := token.NewFileSet()
|
||||
// parse source code for comments
|
||||
f, err := parser.ParseFile(fset, filename, string(data), parser.ParseComments)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// used to record the code interval corresponding to the function to be deleted (including its Doc comment)
|
||||
var removeIntervals []interval
|
||||
// used to collect function names and comment information that contain panic ("implementation me")
|
||||
var panicFuncInfos []FuncInfo
|
||||
|
||||
var isMatch bool
|
||||
|
||||
// traverse declarations in the file, keeping only qualified function declarations
|
||||
var decls []ast.Decl
|
||||
for _, decl := range f.Decls {
|
||||
funcDecl, ok := decl.(*ast.FuncDecl)
|
||||
if !ok {
|
||||
decls = append(decls, decl)
|
||||
continue
|
||||
}
|
||||
|
||||
// preserve if function name starts with New
|
||||
if strings.HasPrefix(funcDecl.Name.Name, "New") {
|
||||
decls = append(decls, decl)
|
||||
continue
|
||||
}
|
||||
// if the function body is nil, it remains
|
||||
if funcDecl.Body == nil {
|
||||
decls = append(decls, decl)
|
||||
continue
|
||||
}
|
||||
|
||||
// if matches call panic("implement me") and has function comment, the function is retained
|
||||
if containsPanicCall(funcDecl, customFlag...) {
|
||||
commentText := ""
|
||||
if funcDecl.Doc != nil {
|
||||
var parts []string
|
||||
for _, c := range funcDecl.Doc.List {
|
||||
parts = append(parts, strings.TrimSpace(c.Text))
|
||||
}
|
||||
commentText = strings.Join(parts, "\n")
|
||||
}
|
||||
panicFuncInfos = append(panicFuncInfos, FuncInfo{
|
||||
Name: funcDecl.Name.Name,
|
||||
Comment: commentText,
|
||||
})
|
||||
if commentText != "" {
|
||||
decls = append(decls, decl)
|
||||
isMatch = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// delete other cases: record deletion interval
|
||||
start := funcDecl.Pos()
|
||||
if funcDecl.Doc != nil {
|
||||
start = funcDecl.Doc.Pos()
|
||||
}
|
||||
removeIntervals = append(removeIntervals, interval{start: start, end: funcDecl.End()})
|
||||
}
|
||||
if !isMatch {
|
||||
return nil, nil, fmt.Errorf("no function satisfies both conditions: 1. the function body contains the" +
|
||||
" panic(\"implement me\") marker, and 2. the function includes a comment describing its functionality")
|
||||
}
|
||||
|
||||
f.Decls = decls
|
||||
|
||||
// filter comment groups to remove comments that fall within the deletion interval
|
||||
var newComments []*ast.CommentGroup
|
||||
COMMENT:
|
||||
for _, cg := range f.Comments {
|
||||
for _, rem := range removeIntervals {
|
||||
if cg.Pos() >= rem.start && cg.End() <= rem.end {
|
||||
continue COMMENT
|
||||
}
|
||||
}
|
||||
newComments = append(newComments, cg)
|
||||
}
|
||||
f.Comments = newComments
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err = printer.Fprint(&buf, fset, f); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return buf.Bytes(), panicFuncInfos, nil
|
||||
}
|
40
pkg/goast/filter_test.go
Normal file
40
pkg/goast/filter_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package goast
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// This is a demo function for testing, default panic message is "implement me"
|
||||
func demoFn1() {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// This is a demo function for testing, default panic message is "ai todo"
|
||||
func demoFn2() {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// This is a demo function for testing, default panic message is "foobar"
|
||||
func demoFn3() {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestFilterFuncCodeByFile(t *testing.T) {
|
||||
code, infos, err := FilterFuncCodeByFile("filter_test.go")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.Equal(t, 3, len(infos))
|
||||
assert.Equal(t, "demoFn1", infos[0].Name)
|
||||
assert.Contains(t, infos[0].ExtractComment(), `"implement me"`)
|
||||
//t.Log(code)
|
||||
|
||||
code, infos, err = FilterFuncCodeByFile("filter_test.go", "ai todo", "foobar")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.Equal(t, 3, len(infos))
|
||||
assert.Equal(t, "demoFn3", infos[2].Name)
|
||||
assert.Contains(t, infos[2].ExtractComment(), `"foobar"`)
|
||||
//t.Log(code)
|
||||
}
|
799
pkg/goast/merge.go
Normal file
799
pkg/goast/merge.go
Normal file
@@ -0,0 +1,799 @@
|
||||
package goast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type CodeAstOption func(*CodeAst)
|
||||
|
||||
func defaultClientOptions() *CodeAst {
|
||||
return &CodeAst{
|
||||
ignoreFuncNameMap: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *CodeAst) apply(opts ...CodeAstOption) {
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
}
|
||||
|
||||
// WithCoverSameFunc sets cover same function in the merged code
|
||||
func WithCoverSameFunc() CodeAstOption {
|
||||
return func(a *CodeAst) {
|
||||
a.isCoverSameFunc = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithIgnoreMergeFunc sets ignore to merge the same function name in the two code
|
||||
func WithIgnoreMergeFunc(funcName ...string) CodeAstOption {
|
||||
return func(a *CodeAst) {
|
||||
for _, name := range funcName {
|
||||
a.ignoreFuncNameMap[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CodeAst is the struct for code
|
||||
type CodeAst struct {
|
||||
FilePath string
|
||||
Code string
|
||||
|
||||
AstInfos []*AstInfo
|
||||
packageInfo *AstInfo
|
||||
importInfos []*AstInfo
|
||||
constInfos []*AstInfo
|
||||
varInfos []*AstInfo
|
||||
typeInfos []*AstInfo
|
||||
funcInfos []*AstInfo
|
||||
|
||||
nonExistedConstCode []string
|
||||
nonExistedVarCode []string
|
||||
nonExistedTypeInfoMap map[string]*TypeInfo // key is type name
|
||||
mergedStructMethodsMap map[string]struct{} // key is struct name
|
||||
ignoreFuncNameMap map[string]struct{} // key is function name
|
||||
|
||||
changeCodeFlag bool
|
||||
isCoverSameFunc bool
|
||||
}
|
||||
|
||||
// NewCodeAst creates a new CodeAst object from file path
|
||||
func NewCodeAst(filePath string, opts ...CodeAstOption) (*CodeAst, error) {
|
||||
o := defaultClientOptions()
|
||||
o.apply(opts...)
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
astInfos, err := ParseGoCode(filePath, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
codeAst := &CodeAst{
|
||||
FilePath: filePath,
|
||||
Code: string(data),
|
||||
AstInfos: astInfos,
|
||||
nonExistedTypeInfoMap: make(map[string]*TypeInfo),
|
||||
mergedStructMethodsMap: make(map[string]struct{}),
|
||||
isCoverSameFunc: o.isCoverSameFunc,
|
||||
ignoreFuncNameMap: o.ignoreFuncNameMap,
|
||||
}
|
||||
codeAst.setSlices()
|
||||
|
||||
return codeAst, nil
|
||||
}
|
||||
|
||||
// NewCodeAstFromData creates a new CodeAst object from data
|
||||
func NewCodeAstFromData(data []byte, opts ...CodeAstOption) (*CodeAst, error) {
|
||||
o := defaultClientOptions()
|
||||
o.apply(opts...)
|
||||
|
||||
astInfos, err := ParseGoCode("", data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
codeAst := &CodeAst{
|
||||
Code: string(data),
|
||||
AstInfos: astInfos,
|
||||
nonExistedTypeInfoMap: make(map[string]*TypeInfo),
|
||||
mergedStructMethodsMap: make(map[string]struct{}),
|
||||
isCoverSameFunc: o.isCoverSameFunc,
|
||||
ignoreFuncNameMap: o.ignoreFuncNameMap,
|
||||
}
|
||||
codeAst.setSlices()
|
||||
|
||||
return codeAst, nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) setSlices() {
|
||||
for _, astInfo := range a.AstInfos {
|
||||
switch astInfo.Type {
|
||||
case PackageType:
|
||||
a.packageInfo = astInfo
|
||||
case ImportType:
|
||||
a.importInfos = append(a.importInfos, astInfo)
|
||||
case ConstType:
|
||||
a.constInfos = append(a.constInfos, astInfo)
|
||||
case VarType:
|
||||
a.varInfos = append(a.varInfos, astInfo)
|
||||
case TypeType:
|
||||
a.typeInfos = append(a.typeInfos, astInfo)
|
||||
case FuncType:
|
||||
a.funcInfos = append(a.funcInfos, astInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *CodeAst) mergeImportCode(genAst *CodeAst) error {
|
||||
if len(genAst.importInfos) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. append import code to package
|
||||
|
||||
if len(a.importInfos) == 0 {
|
||||
srcStr := a.packageInfo.Body
|
||||
dstStr := ""
|
||||
for _, info := range genAst.AstInfos {
|
||||
if info.IsImportType() {
|
||||
dstStr += info.Body + "\n"
|
||||
}
|
||||
}
|
||||
if strings.Count(a.Code, srcStr) > 1 {
|
||||
return errDuplication("mergeImportCode", srcStr)
|
||||
}
|
||||
|
||||
a.Code = strings.Replace(a.Code, srcStr, srcStr+"\n\n"+dstStr, 1)
|
||||
a.changeCodeFlag = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. append import code to import
|
||||
|
||||
srcImportInfos, err := a.parseImportCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
genImportInfos, err := genAst.parseImportCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srcLen := len(srcImportInfos)
|
||||
srcImportInfoMap := make(map[string]struct{}, srcLen)
|
||||
for _, srcIi := range srcImportInfos {
|
||||
srcImportInfoMap[srcIi.Path] = struct{}{}
|
||||
}
|
||||
|
||||
var nonExistedImportInfos []*ImportInfo
|
||||
//var nonExistedImportPaths []string
|
||||
for _, genIfi := range genImportInfos {
|
||||
if _, ok := srcImportInfoMap[genIfi.Path]; !ok {
|
||||
nonExistedImportInfos = append(nonExistedImportInfos, genIfi)
|
||||
//nonExistedImportPaths = append(nonExistedImportPaths, genIfi.Path)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonExistedImportInfos) > 0 {
|
||||
var srcStr = a.packageInfo.Body
|
||||
var dstStr = "import (\n"
|
||||
for _, info := range srcImportInfos {
|
||||
if info.Comment != "" {
|
||||
dstStr += info.Comment + "\n"
|
||||
}
|
||||
dstStr += "\t" + trimBody(info.Body, ImportType) + "\n"
|
||||
}
|
||||
for _, info := range nonExistedImportInfos {
|
||||
if info.Comment != "" {
|
||||
dstStr += info.Comment + "\n"
|
||||
}
|
||||
dstStr += "\t" + trimBody(info.Body, ImportType) + "\n"
|
||||
}
|
||||
dstStr += ")"
|
||||
|
||||
a.Code = strings.Replace(a.Code, srcStr, srcStr+"\n\n"+dstStr, 1)
|
||||
a.changeCodeFlag = true
|
||||
for _, info := range a.importInfos {
|
||||
a.Code = strings.Replace(a.Code, info.Body, "", 1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) compareConstCode(genAst *CodeAst) error {
|
||||
if len(genAst.constInfos) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(a.constInfos) == 0 {
|
||||
dstStr := ""
|
||||
for _, info := range genAst.AstInfos {
|
||||
if info.IsConstType() {
|
||||
if info.Comment != "" {
|
||||
dstStr += info.Comment + "\n"
|
||||
}
|
||||
dstStr += info.Body + "\n"
|
||||
}
|
||||
}
|
||||
a.nonExistedConstCode = append(a.nonExistedConstCode, dstStr)
|
||||
return nil
|
||||
}
|
||||
|
||||
srcConstInfos, err := a.parseConstCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
genConstInfos, err := genAst.parseConstCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srcLen := len(srcConstInfos)
|
||||
srcConstInfoMap := make(map[string]struct{}, srcLen)
|
||||
for _, srcCi := range srcConstInfos {
|
||||
srcConstInfoMap[srcCi.Name] = struct{}{}
|
||||
}
|
||||
|
||||
var nonExistedConstInfos []*ConstInfo
|
||||
//var nonExistedConstNames []string
|
||||
for _, genCi := range genConstInfos {
|
||||
if _, ok := srcConstInfoMap[genCi.Name]; !ok {
|
||||
nonExistedConstInfos = append(nonExistedConstInfos, genCi)
|
||||
//nonExistedConstNames = append(nonExistedConstNames, genCi.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonExistedConstInfos) > 0 {
|
||||
var dstStr = "const (\n"
|
||||
for _, info := range nonExistedConstInfos {
|
||||
if info.Comment != "" {
|
||||
if !strings.HasPrefix(info.Comment, "\t") {
|
||||
info.Comment = "\t" + info.Comment
|
||||
}
|
||||
dstStr += info.Comment + "\n"
|
||||
}
|
||||
dstStr += "\t" + trimBody(info.Body, ConstType) + "\n"
|
||||
}
|
||||
dstStr += ")\n"
|
||||
a.nonExistedConstCode = append(a.nonExistedConstCode, dstStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) compareVarCode(genAst *CodeAst) error {
|
||||
if len(genAst.varInfos) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(a.varInfos) == 0 {
|
||||
dstStr := ""
|
||||
for _, info := range genAst.AstInfos {
|
||||
if info.IsVarType() {
|
||||
if info.Comment != "" {
|
||||
dstStr += info.Comment + "\n"
|
||||
}
|
||||
dstStr += info.Body + "\n"
|
||||
}
|
||||
}
|
||||
a.nonExistedVarCode = append(a.nonExistedVarCode, dstStr)
|
||||
return nil
|
||||
}
|
||||
|
||||
srcVarInfos, err := a.parseVarCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
genVarInfos, err := genAst.parseVarCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srcLen := len(srcVarInfos)
|
||||
srcVarInfoMap := make(map[string]struct{}, srcLen)
|
||||
for _, srcVi := range srcVarInfos {
|
||||
srcVarInfoMap[srcVi.Name] = struct{}{}
|
||||
}
|
||||
|
||||
var nonExistedVarInfos []*VarInfo
|
||||
//var nonExistedVarNames []string
|
||||
for _, genVi := range genVarInfos {
|
||||
if _, ok := srcVarInfoMap[genVi.Name]; !ok {
|
||||
nonExistedVarInfos = append(nonExistedVarInfos, genVi)
|
||||
//nonExistedVarNames = append(nonExistedVarNames, genVi.Name)
|
||||
} else {
|
||||
if genVi.Name == "_" && !strings.Contains(a.Code, strings.TrimSpace(genVi.Body)) {
|
||||
nonExistedVarInfos = append(nonExistedVarInfos, genVi)
|
||||
//nonExistedVarNames = append(nonExistedVarNames, genVi.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonExistedVarInfos) > 0 {
|
||||
var dstStr = "var (\n"
|
||||
for _, info := range nonExistedVarInfos {
|
||||
if info.Comment != "" {
|
||||
if !strings.HasPrefix(info.Comment, "\t") {
|
||||
info.Comment = "\t" + info.Comment
|
||||
}
|
||||
dstStr += info.Comment + "\n"
|
||||
}
|
||||
dstStr += "\t" + trimBody(info.Body, VarType) + "\n"
|
||||
}
|
||||
dstStr += ")\n"
|
||||
a.nonExistedVarCode = append(a.nonExistedVarCode, dstStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) mergeExistedTypeCode(genAst *CodeAst) error {
|
||||
srcTypeInfosMap, err := a.parseTypeCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
genTypeInfosMap, err := genAst.parseTypeCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var srcTypeNameMap = make(map[string]struct{})
|
||||
for _, srcTypeInfos := range srcTypeInfosMap {
|
||||
for _, info := range srcTypeInfos {
|
||||
srcTypeNameMap[info.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
var nonExistedTypeInfoMap = make(map[string]*TypeInfo)
|
||||
|
||||
for typeName, genTypeInfos := range genTypeInfosMap {
|
||||
// get non-existed type infos
|
||||
for _, info := range genTypeInfos {
|
||||
if _, ok := srcTypeNameMap[info.Name]; !ok {
|
||||
nonExistedTypeInfoMap[info.Name] = info
|
||||
}
|
||||
}
|
||||
|
||||
// merge existed interface method code and struct fields code
|
||||
srcTypeInfos, ok := srcTypeInfosMap[typeName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
srcTypeInfoMap := make(map[string]*TypeInfo, len(srcTypeInfos))
|
||||
for _, srcTi := range srcTypeInfos {
|
||||
srcTypeInfoMap[srcTi.Name] = srcTi
|
||||
}
|
||||
for _, genTypeInfo := range genTypeInfos {
|
||||
switch genTypeInfo.Type {
|
||||
case InterfaceType:
|
||||
if srcTypeInfo, ok := srcTypeInfoMap[genTypeInfo.Name]; ok {
|
||||
err = a.mergeInterfaceMethodCode(srcTypeInfo, genTypeInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case StructType:
|
||||
if srcTypeInfo, ok := srcTypeInfoMap[genTypeInfo.Name]; ok {
|
||||
err = a.mergeStructFieldsCode(srcTypeInfo, genTypeInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.nonExistedTypeInfoMap = nonExistedTypeInfoMap
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) mergeInterfaceMethodCode(srcTypeInfo *TypeInfo, genTypeInfo *TypeInfo) error {
|
||||
srcInterfaceInfos, err := ParseInterface(srcTypeInfo.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
genInterfaceInfos, err := ParseInterface(genTypeInfo.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(srcInterfaceInfos) == 1 && len(genInterfaceInfos) == 1 {
|
||||
srcLastMethodStr := ""
|
||||
mLen := len(srcInterfaceInfos[0].MethodInfos)
|
||||
srcInterfaceMethodInfoMap := make(map[string]struct{}, mLen)
|
||||
for i, srcIm := range srcInterfaceInfos[0].MethodInfos {
|
||||
srcInterfaceMethodInfoMap[srcIm.Name] = struct{}{}
|
||||
if i == mLen-1 {
|
||||
srcLastMethodStr = srcIm.Body
|
||||
}
|
||||
}
|
||||
var newMethods []string
|
||||
for _, genMethodInfo := range genInterfaceInfos[0].MethodInfos {
|
||||
if _, ok := srcInterfaceMethodInfoMap[genMethodInfo.Name]; !ok {
|
||||
newMethodStr := ""
|
||||
if genMethodInfo.Comment != "" {
|
||||
newMethodStr += genMethodInfo.Comment + "\n"
|
||||
}
|
||||
newMethodStr += genMethodInfo.Body
|
||||
newMethods = append(newMethods, newMethodStr)
|
||||
}
|
||||
}
|
||||
if len(newMethods) > 0 {
|
||||
srcStr := srcTypeInfo.Body
|
||||
if strings.Count(a.Code, srcStr) > 1 {
|
||||
return errDuplication("mergeInterfaceMethodCode", srcStr)
|
||||
}
|
||||
dstStr := ""
|
||||
if len(srcInterfaceInfos[0].MethodInfos) == 0 {
|
||||
dstStr = "type " + srcInterfaceInfos[0].Name + " interface {\n" + strings.Join(newMethods, "\n") + "\n}"
|
||||
} else {
|
||||
dstStr = strings.Replace(srcStr, srcLastMethodStr, srcLastMethodStr+"\n"+strings.Join(newMethods, "\n"), 1)
|
||||
}
|
||||
a.Code = strings.Replace(a.Code, srcStr, dstStr, 1)
|
||||
a.changeCodeFlag = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) mergeStructFieldsCode(srcTypeInfo *TypeInfo, genTypeInfo *TypeInfo) error {
|
||||
srcStructInfos, err := ParseStruct(srcTypeInfo.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
genStructInfos, err := ParseStruct(genTypeInfo.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, genStructInfo := range genStructInfos {
|
||||
srcStructInfo, ok := srcStructInfos[name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
fLen := len(srcStructInfo.Fields)
|
||||
srcLastFieldStr := ""
|
||||
srcFieldMap := make(map[string]struct{}, fLen)
|
||||
for i, field := range srcStructInfo.Fields {
|
||||
srcFieldMap[field.Name] = struct{}{}
|
||||
if i == fLen-1 {
|
||||
srcLastFieldStr = field.Body
|
||||
}
|
||||
}
|
||||
|
||||
var newFields []string
|
||||
for _, field := range genStructInfo.Fields {
|
||||
newFieldStr := ""
|
||||
if _, ok := srcFieldMap[field.Name]; !ok {
|
||||
if field.Comment != "" {
|
||||
newFieldStr += field.Comment + "\n"
|
||||
}
|
||||
newFieldStr += field.Body
|
||||
newFields = append(newFields, newFieldStr)
|
||||
}
|
||||
}
|
||||
if len(newFields) > 0 {
|
||||
srcStr := srcTypeInfo.Body
|
||||
if strings.Count(a.Code, srcStr) > 1 {
|
||||
return errDuplication("mergeStructFieldsCode", srcStr)
|
||||
}
|
||||
dstStr := ""
|
||||
if len(srcStructInfo.Fields) == 0 {
|
||||
dstStr = "type " + srcStructInfo.Name + " struct {\n" + strings.Join(newFields, "\n") + "\n}"
|
||||
} else {
|
||||
dstStr = strings.Replace(srcStr, srcLastFieldStr, srcLastFieldStr+"\n"+strings.Join(newFields, "\n"), 1)
|
||||
}
|
||||
a.Code = strings.Replace(a.Code, srcStr, dstStr, 1)
|
||||
a.changeCodeFlag = true
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) mergeStructMethodsCode(genAst *CodeAst) error {
|
||||
srcImportInfoMap := ParseStructMethods(a.AstInfos)
|
||||
genImportInfoMap := ParseStructMethods(genAst.AstInfos)
|
||||
|
||||
for structName, genMethods := range genImportInfoMap {
|
||||
var nonExistedImports []string
|
||||
var lastMethodFuncCode string
|
||||
if srcMethods, ok := srcImportInfoMap[structName]; ok {
|
||||
var srcMethodMap = make(map[string]struct{}, len(srcMethods))
|
||||
for i, srcMethod := range srcMethods {
|
||||
srcMethodMap[srcMethod.Name] = struct{}{}
|
||||
if i == len(srcMethods)-1 {
|
||||
lastMethodFuncCode = srcMethod.Body
|
||||
}
|
||||
}
|
||||
for _, genMethod := range genMethods {
|
||||
if _, isExisted := srcMethodMap[genMethod.Name]; !isExisted {
|
||||
nonExistedImports = append(nonExistedImports, genMethod.Comment+"\n"+genMethod.Body)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(nonExistedImports) > 0 {
|
||||
srcStr := lastMethodFuncCode
|
||||
if strings.Count(a.Code, srcStr) > 1 {
|
||||
return errDuplication("mergeStructMethodsCode", srcStr)
|
||||
}
|
||||
dstStr := lastMethodFuncCode + "\n\n" + strings.Join(nonExistedImports, "\n\n")
|
||||
a.Code = strings.Replace(a.Code, srcStr, dstStr, 1)
|
||||
a.changeCodeFlag = true
|
||||
a.mergedStructMethodsMap[structName] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) coverFuncCode(genAst *CodeAst) {
|
||||
var srcFuncNameMap = make(map[string]*AstInfo)
|
||||
for _, srcFuncInfo := range a.funcInfos {
|
||||
srcFuncNameMap[srcFuncInfo.GetName()] = srcFuncInfo
|
||||
}
|
||||
for _, genFuncInfo := range genAst.funcInfos {
|
||||
genFuncName := genFuncInfo.GetName()
|
||||
if genFuncName == "init" || genFuncName == "_" {
|
||||
continue
|
||||
}
|
||||
|
||||
var ignoreFuncName string
|
||||
if len(genFuncInfo.Names) == 2 {
|
||||
ignoreFuncName = genFuncInfo.Names[0]
|
||||
} else {
|
||||
ignoreFuncName = genFuncName
|
||||
}
|
||||
if _, ok := a.ignoreFuncNameMap[ignoreFuncName]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if srcFuncInfo, ok := srcFuncNameMap[genFuncName]; ok {
|
||||
srcStr := srcFuncInfo.Body
|
||||
dstStr := genFuncInfo.Body
|
||||
comment := ""
|
||||
if srcFuncInfo.Comment == "" && genFuncInfo.Comment != "" {
|
||||
comment = genFuncInfo.Comment
|
||||
}
|
||||
if comment != "" {
|
||||
dstStr = comment + "\n" + dstStr
|
||||
}
|
||||
a.Code = strings.Replace(a.Code, srcStr, dstStr, 1)
|
||||
a.changeCodeFlag = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// appends non-existed code to the end of the source code.
|
||||
func (a *CodeAst) appendNonExistedCode(genAsts []*AstInfo) error { // nolint
|
||||
srcNameAstMap := make(map[string]struct{}, len(a.AstInfos))
|
||||
for _, info := range a.AstInfos {
|
||||
srcNameAstMap[info.GetName()] = struct{}{}
|
||||
}
|
||||
|
||||
if len(a.nonExistedConstCode) > 0 {
|
||||
a.Code += "\n" + strings.Join(a.nonExistedConstCode, "\n")
|
||||
a.changeCodeFlag = true
|
||||
}
|
||||
if len(a.nonExistedVarCode) > 0 {
|
||||
a.Code += "\n" + strings.Join(a.nonExistedVarCode, "\n")
|
||||
a.changeCodeFlag = true
|
||||
}
|
||||
|
||||
var appendCodes []string
|
||||
for _, genAst := range genAsts {
|
||||
if genAst.IsPackageType() || genAst.IsImportType() || genAst.IsConstType() || genAst.IsVarType() {
|
||||
continue
|
||||
}
|
||||
if genAst.IsFuncType() && len(genAst.Names) == 2 {
|
||||
if _, ok := a.mergedStructMethodsMap[genAst.Names[1]]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
isNeedAppend := false
|
||||
name := genAst.GetName()
|
||||
if _, ok := srcNameAstMap[name]; !ok {
|
||||
if genAst.IsTypeType() && len(genAst.Names) > 1 {
|
||||
var nonExistedTypes []string
|
||||
for _, name := range genAst.Names {
|
||||
var dstStr string
|
||||
if info, ok := a.nonExistedTypeInfoMap[name]; ok {
|
||||
if info.Comment != "" {
|
||||
dstStr += info.Comment + "\n"
|
||||
}
|
||||
dstStr += info.Body
|
||||
nonExistedTypes = append(nonExistedTypes, dstStr)
|
||||
}
|
||||
}
|
||||
if len(nonExistedTypes) > 0 {
|
||||
appendCodes = append(appendCodes, "type (\n"+strings.Join(nonExistedTypes, "\n")+"\n)")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
isNeedAppend = true
|
||||
} else {
|
||||
if name == "_" && !strings.Contains(a.Code, genAst.Body) {
|
||||
isNeedAppend = true
|
||||
}
|
||||
if genAst.IsFuncType() && name == "init" && !strings.Contains(a.Code, genAst.Body) {
|
||||
isNeedAppend = true
|
||||
}
|
||||
}
|
||||
if isNeedAppend {
|
||||
comment := ""
|
||||
if genAst.Comment != "" {
|
||||
comment = genAst.Comment
|
||||
}
|
||||
appendCodes = append(appendCodes, comment+"\n"+genAst.Body)
|
||||
}
|
||||
}
|
||||
|
||||
if len(appendCodes) > 0 {
|
||||
a.Code += strings.Join(appendCodes, "\n\n") + "\n"
|
||||
a.changeCodeFlag = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CodeAst) parseImportCode() ([]*ImportInfo, error) {
|
||||
body := ""
|
||||
for _, info := range a.importInfos {
|
||||
if info.Comment != "" {
|
||||
body += info.Comment + "\n"
|
||||
}
|
||||
body += info.Body + "\n"
|
||||
}
|
||||
return ParseImportGroup(body)
|
||||
}
|
||||
|
||||
func (a *CodeAst) parseConstCode() ([]*ConstInfo, error) {
|
||||
body := ""
|
||||
for _, info := range a.constInfos {
|
||||
if info.Comment != "" {
|
||||
body += info.Comment + "\n"
|
||||
}
|
||||
body += info.Body + "\n"
|
||||
}
|
||||
return ParseConstGroup(body)
|
||||
}
|
||||
|
||||
func (a *CodeAst) parseVarCode() ([]*VarInfo, error) {
|
||||
body := ""
|
||||
for _, info := range a.varInfos {
|
||||
if info.Comment != "" {
|
||||
body += info.Comment + "\n"
|
||||
}
|
||||
body += info.Body + "\n"
|
||||
}
|
||||
return ParseVarGroup(body)
|
||||
}
|
||||
|
||||
func (a *CodeAst) parseTypeCode() (map[string][]*TypeInfo, error) {
|
||||
body := ""
|
||||
for _, info := range a.typeInfos {
|
||||
if info.Comment != "" {
|
||||
body += info.Comment + "\n"
|
||||
}
|
||||
body += info.Body + "\n"
|
||||
}
|
||||
|
||||
typeInfos, err := ParseTypeGroup(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
typeMap := make(map[string][]*TypeInfo, len(typeInfos))
|
||||
for _, info := range typeInfos {
|
||||
if info.Name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := typeMap[info.Name]; !ok {
|
||||
typeMap[info.Name] = []*TypeInfo{}
|
||||
}
|
||||
typeMap[info.Name] = append(typeMap[info.Name], info)
|
||||
}
|
||||
return typeMap, nil
|
||||
}
|
||||
|
||||
func errDuplication(marker string, srcStr string) error {
|
||||
return fmt.Errorf("%s: multiple duplicate string `%s` exists, please modify the source code to ensure uniqueness", marker, srcStr)
|
||||
}
|
||||
|
||||
func trimBody(body string, codeType string) string {
|
||||
body = strings.TrimSpace(body)
|
||||
return strings.TrimPrefix(body, codeType+" ")
|
||||
}
|
||||
|
||||
// MergeGoFile merges two Go code files into one.
|
||||
func MergeGoFile(srcFile string, genFile string, opts ...CodeAstOption) (*CodeAst, error) {
|
||||
srcAst, err := NewCodeAst(srcFile, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
genAst, err := NewCodeAst(genFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mergeCode(srcAst, genAst)
|
||||
}
|
||||
|
||||
// MergeGoCode merges two Go code strings into one.
|
||||
func MergeGoCode(srcCode []byte, genCode []byte, opts ...CodeAstOption) (*CodeAst, error) {
|
||||
srcAst, err := NewCodeAstFromData(srcCode, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
genAst, err := NewCodeAstFromData(genCode, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mergeCode(srcAst, genAst)
|
||||
}
|
||||
|
||||
func mergeCode(srcAst *CodeAst, genAst *CodeAst) (*CodeAst, error) {
|
||||
// merge import code
|
||||
err := srcAst.mergeImportCode(genAst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// compare const code
|
||||
err = srcAst.compareConstCode(genAst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// compare var code
|
||||
err = srcAst.compareVarCode(genAst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// merge interface method and struct fields code
|
||||
err = srcAst.mergeExistedTypeCode(genAst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// merge struct method function code
|
||||
err = srcAst.mergeStructMethodsCode(genAst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if srcAst.isCoverSameFunc {
|
||||
// cover same function code
|
||||
srcAst.coverFuncCode(genAst)
|
||||
}
|
||||
|
||||
// append non-existed code
|
||||
err = srcAst.appendNonExistedCode(genAst.AstInfos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if srcAst.changeCodeFlag {
|
||||
data, err := format.Source([]byte(srcAst.Code))
|
||||
if err == nil {
|
||||
srcAst.Code = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
return srcAst, nil
|
||||
}
|
110
pkg/goast/merge_test.go
Normal file
110
pkg/goast/merge_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package goast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
srcFile = "data/src.go.code"
|
||||
genFile = "data/gen.go.code"
|
||||
)
|
||||
|
||||
func TestMergeGoFile(t *testing.T) {
|
||||
t.Run("without cover same func", func(t *testing.T) {
|
||||
codeAst, err := MergeGoFile(srcFile, genFile)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, codeAst.changeCodeFlag, true)
|
||||
fmt.Println(codeAst.Code)
|
||||
})
|
||||
|
||||
t.Run("cover same func", func(t *testing.T) {
|
||||
codeAst, err := MergeGoFile(srcFile, genFile, WithCoverSameFunc())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, codeAst.changeCodeFlag, true)
|
||||
fmt.Println(codeAst.Code)
|
||||
})
|
||||
|
||||
t.Run("test same file", func(t *testing.T) {
|
||||
codeAst, err := MergeGoFile(srcFile, srcFile)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, codeAst.changeCodeFlag, false)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMergeGoCode(t *testing.T) {
|
||||
t.Run("without cover same func", func(t *testing.T) {
|
||||
srcData, genData, _ := getGoCode()
|
||||
codeAst, err := MergeGoCode(srcData, genData)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, codeAst.changeCodeFlag, true)
|
||||
fmt.Println(codeAst.Code)
|
||||
})
|
||||
|
||||
t.Run("cover same func", func(t *testing.T) {
|
||||
srcData, genData, _ := getGoCode()
|
||||
codeAst, err := MergeGoCode(srcData, genData,
|
||||
WithCoverSameFunc(),
|
||||
WithIgnoreMergeFunc("GetByID", "Hi"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, codeAst.changeCodeFlag, true)
|
||||
fmt.Println(codeAst.Code)
|
||||
})
|
||||
|
||||
t.Run("test same file", func(t *testing.T) {
|
||||
srcData, _, _ := getGoCode()
|
||||
codeAst, err := MergeGoCode(srcData, srcData)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, codeAst.changeCodeFlag, false)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewCodeAstFromData(t *testing.T) {
|
||||
srcData, _, err := getGoCode()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
codeAst, err := NewCodeAstFromData(srcData)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
codeAst.FilePath = srcFile
|
||||
assert.Equal(t, true, len(codeAst.AstInfos) > 0)
|
||||
}
|
||||
|
||||
func getGoCode() ([]byte, []byte, error) {
|
||||
srcData, err := os.ReadFile(srcFile)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
}
|
||||
genData, err := os.ReadFile(genFile)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return srcData, genData, nil
|
||||
}
|
Reference in New Issue
Block a user