refactor: adjust the logic of generating gRPC code

This commit is contained in:
zhuyasen
2025-04-06 19:29:12 +08:00
parent 137031562f
commit 70b6d66c93
4 changed files with 26 additions and 50 deletions

View File

@@ -10,15 +10,16 @@ import (
)
// GenerateFiles generate service template code and error codes
func GenerateFiles(file *protogen.File, moduleName string) ([]byte, []byte, []byte) {
func GenerateFiles(file *protogen.File, moduleName string) (serviceTmplContent []byte,
serviceTestTmplContent []byte, errCodeFileContent []byte) {
if len(file.Services) == 0 {
return nil, nil, nil
}
pss := parse.GetServices(file, moduleName)
serviceTmplContent := genServiceTmplFile(pss)
serviceTestTmplContent := genServiceTestTmplFile(pss)
errCodeFileContent := genErrCodeFile(pss)
serviceTmplContent = genServiceTmplFile(pss)
serviceTestTmplContent = genServiceTestTmplFile(pss)
errCodeFileContent = genErrCodeFile(pss)
return serviceTmplContent, serviceTestTmplContent, errCodeFileContent
}
@@ -47,7 +48,7 @@ func (f *serviceTmplFields) execute() []byte {
if err := serviceLogicTmpl.Execute(buf, f); err != nil {
panic(err)
}
content := handleSplitLineMark(buf.Bytes())
content := buf.Bytes()
return bytes.ReplaceAll(content, []byte(importPkgPathMark), parse.GetImportPkg(f.PbServices))
}
@@ -74,25 +75,7 @@ func (f *errCodeFields) execute() []byte {
panic(err)
}
data := bytes.ReplaceAll(buf.Bytes(), []byte("// --blank line--"), []byte{})
return handleSplitLineMark(data)
return data
}
const importPkgPathMark = "// import api service package here"
var splitLineMark = []byte(`// ---------- Do not delete or move this split line, this is the merge code marker ----------`)
func handleSplitLineMark(data []byte) []byte {
ss := bytes.Split(data, splitLineMark)
if len(ss) <= 2 {
return ss[0]
}
var out []byte
for i, s := range ss {
out = append(out, s...)
if i < len(ss)-2 {
out = append(out, splitLineMark...)
}
}
return out
}

View File

@@ -80,7 +80,7 @@ func New{{.Name}}Server() {{.ProtoPkgName}}.{{.Name}}Server {
{{if eq .InvokeType 1}}
{{.Comment}}
func (s *{{.LowerServiceName}}) {{.MethodName}}(stream {{.RequestImportPkgName}}.{{.ServiceName}}_{{.MethodName}}Server) error {
panic("{{.Prompt}}")
panic("implement me")
// fill in the business logic code here
// example:
@@ -118,7 +118,7 @@ func (s *{{.LowerServiceName}}) {{.MethodName}}(stream {{.RequestImportPkgName}}
{{else if eq .InvokeType 2}}
{{.Comment}}
func (s *{{.LowerServiceName}}) {{.MethodName}}(req *{{.RequestImportPkgName}}.{{.Request}}, stream {{.ReplyImportPkgName}}.{{.ServiceName}}_{{.MethodName}}Server) error {
panic("{{.Prompt}}")
panic("implement me")
// fill in the business logic code here
// example:
@@ -155,7 +155,7 @@ func (s *{{.LowerServiceName}}) {{.MethodName}}(req *{{.RequestImportPkgName}}.{
{{else if eq .InvokeType 3}}
{{.Comment}}
func (s *{{.LowerServiceName}}) {{.MethodName}}(stream {{.RequestImportPkgName}}.{{.ServiceName}}_{{.MethodName}}Server) error {
panic("{{.Prompt}}")
panic("implement me")
// fill in the business logic code here
// example:
@@ -199,7 +199,7 @@ func (s *{{.LowerServiceName}}) {{.MethodName}}(stream {{.RequestImportPkgName}}
{{else}}
{{.Comment}}
func (s *{{.LowerServiceName}}) {{.MethodName}}(ctx context.Context, req *{{.RequestImportPkgName}}.{{.Request}}) (*{{.ReplyImportPkgName}}.{{.Reply}}, error) {
panic("{{.Prompt}}")
panic("implement me")
// fill in the business logic code here
// example:
@@ -229,8 +229,6 @@ func (s *{{.LowerServiceName}}) {{.MethodName}}(ctx context.Context, req *{{.Req
{{end}}
{{- end}}
// ---------- Do not delete or move this split line, this is the merge code marker ----------
{{- end}}
`
@@ -546,9 +544,9 @@ import (
{{- range .PbServices}}
// {{.LowerName}} business-level rpc error codes.
// the _{{.LowerName}}NO value range is 1~100, if the same error code is used, it will cause panic.
// the _{{.LowerName}}NO value range is 1~999, if the same error code is used, it will cause panic.
var (
_{{.LowerName}}NO = {{.RandNumber}}
_{{.LowerName}}NO = 1
_{{.LowerName}}Name = "{{.LowerName}}"
_{{.LowerName}}BaseCode = errcode.RCode(_{{.LowerName}}NO)
// --blank line--
@@ -559,8 +557,6 @@ var (
// error codes are globally unique, adding 1 to the previous error code
)
// ---------- Do not delete or move this split line, this is the merge code marker ----------
{{- end}}
`
)

View File

@@ -19,7 +19,6 @@ type ServiceMethod struct {
Reply string // e.g. CreateReply
ReplyFields []*Field
Comment string // e.g. Create a record
Prompt string // from comments, used in AI assistant
InvokeType int // 0:unary, 1: client-side streaming, 2: server-side streaming, 3: bidirectional streaming
ServiceName string // Greeter
@@ -100,8 +99,8 @@ func parsePbService(s *protogen.Service, protoFileDir string, moduleName string)
Reply: m.Output.GoIdent.GoName,
ReplyFields: getFields(m.Output),
Comment: comment,
Prompt: getPrompt(m, comment),
InvokeType: getInvokeType(m.Desc.IsStreamingClient(), m.Desc.IsStreamingServer()),
InvokeType: getInvokeType(m.Desc.IsStreamingClient(), m.Desc.IsStreamingServer()),
ServiceName: s.GoName,
LowerServiceName: strings.ToLower(s.GoName[:1]) + s.GoName[1:],
@@ -165,18 +164,6 @@ func getMethodComment(m *protogen.Method) string {
return commentPrefix + "......"
}
func getPrompt(m *protogen.Method, comment string) string {
if strings.HasSuffix(comment, "......") {
return "prompt: implement me"
}
prompt := strings.TrimPrefix(comment, "// "+m.GoName)
prompt = strings.TrimSpace(prompt)
prompt = strings.ReplaceAll(prompt, "\n//", " ")
prompt = strings.ReplaceAll(prompt, "\r//", " ")
prompt = strings.ReplaceAll(prompt, "\r\n//", " ")
return "prompt: " + prompt
}
func getComment(commentSet protogen.CommentSet) string {
comment1 := getCommentStr(commentSet.Leading.String())
comment2 := getCommentStr(commentSet.Trailing.String())

View File

@@ -14,6 +14,7 @@ import (
"google.golang.org/protobuf/types/pluginpb"
"github.com/go-dev-frame/sponge/cmd/protoc-gen-go-rpc-tmpl/internal/generate/service"
"github.com/go-dev-frame/sponge/pkg/gofile"
)
const (
@@ -106,7 +107,7 @@ func saveRPCTmplFiles(f *protogen.File, moduleName string, serverName string, tm
}
filePath = filenamePrefix + "_client_test.go"
err = saveFile(moduleName, serverName, tmplOut, filePath, testTmplFileContent, true, suitedMonoRepo)
err = saveFile(moduleName, serverName, tmplOut, filePath, testTmplFileContent, false, suitedMonoRepo)
if err != nil {
return err
}
@@ -136,6 +137,7 @@ func saveFile(moduleName string, serverName string, out string, filePath string,
_, name := filepath.Split(filePath)
file := out + "/" + name
if !isNeedCovered && isExists(file) {
removeOldGenFile(file)
file += ".gen" + time.Now().Format("20060102T150405")
}
@@ -158,6 +160,7 @@ func saveFileSimple(out string, filePath string, content []byte, isNeedCovered b
_, name := filepath.Split(filePath)
file := out + "/" + name
if !isNeedCovered && isExists(file) {
removeOldGenFile(file)
file += ".gen" + time.Now().Format("20060102T150405")
}
@@ -172,6 +175,13 @@ func isExists(f string) bool {
return true
}
func removeOldGenFile(file string) {
oldGenFiles := gofile.FuzzyMatchFiles(file + ".gen*")
for _, oldGenFile := range oldGenFiles {
_ = os.Remove(oldGenFile)
}
}
func firstLetterToUpper(s string) []byte {
if s == "" {
return []byte{}