diff --git a/console/commands/swagger.go b/console/commands/swagger.go index 3a28ef9..574bffb 100644 --- a/console/commands/swagger.go +++ b/console/commands/swagger.go @@ -6,9 +6,9 @@ import ( "github.com/ctfang/command" "github.com/go-home-admin/toolset/console/commands/openapi" "github.com/go-home-admin/toolset/parser" - "io/ioutil" "os" path2 "path" + "regexp" "strconv" "strings" ) @@ -106,7 +106,7 @@ func (SwaggerCommand) Execute(input command.Input) { if !parser.DirIsExist(path2.Dir(out)) { _ = os.MkdirAll(path2.Dir(out), 0760) } - err = ioutil.WriteFile(out, by, 0766) + err = os.WriteFile(out, by, 0766) if err != nil { fmt.Println("gen openapi.json err " + err.Error() + ", out = " + out) } else { @@ -191,17 +191,20 @@ func messageToParameters(message string, nowDirProtoc []parser.ProtocFileParser, return got } for _, option := range protocMessage.Attr { + doc, isRequired := filterRequired(option.Doc) + doc = getTitle(doc) if option.Repeated { if isProtoBaseType(option.Ty) { // 基础类型的数组 attr := &openapi.Parameter{ Name: option.Name, - Description: option.Doc, + Description: doc, Enum: nil, Format: option.Ty, In: "query", + Required: isRequired, Items: &openapi.Schema{ - Description: getTitle(option.Doc), + Description: doc, Type: getProtoToSwagger(option.Ty), Format: option.Ty, }, @@ -211,12 +214,14 @@ func messageToParameters(message string, nowDirProtoc []parser.ProtocFileParser, } else { // 引用其他对象 attr := &openapi.Parameter{ - Name: option.Name, - Type: "array", - In: "query", + Name: option.Name, + Description: doc, + Type: "array", + In: "query", + Required: isRequired, Items: &openapi.Schema{ Ref: getRef(pge, option.Ty), - Description: getTitle(option.Doc), + Description: doc, Type: "object", Format: option.Ty, }, @@ -227,19 +232,21 @@ func messageToParameters(message string, nowDirProtoc []parser.ProtocFileParser, attr := &openapi.Parameter{ Name: option.Name, In: "query", - Description: getTitle(option.Doc), + Description: doc, Type: getProtoToSwagger(option.Ty), Format: option.Ty, + Required: isRequired, } got = append(got, attr) } else { // 引用其他对象 attr := &openapi.Parameter{ Name: option.Name, - Description: getTitle(option.Doc), + Description: doc, Type: getProtoToSwagger(option.Ty), Format: option.Ty, In: "query", + Required: isRequired, Schema: &openapi.Schema{ Type: "object", Description: getTitle(option.Doc), @@ -267,15 +274,20 @@ func messageToSchemas(pge string, message parser.Message, swagger *openapi.Spec) schema := &openapi.Schema{} schema.Description = message.Doc properties := make(map[string]*openapi.Schema) - + var requireArr []string for _, option := range message.Attr { + doc, isRequired := filterRequired(option.Doc) + doc = getTitle(doc) + if !isRequired { + requireArr = append(requireArr, option.Name) + } if option.Repeated { if isProtoBaseType(option.Ty) { // 基础类型的数组 attr := &openapi.Schema{ Type: "array", Items: &openapi.Schema{ - Description: getTitle(option.Doc), + Description: doc, Type: getProtoToSwagger(option.Ty), Format: option.Ty, }, @@ -286,7 +298,7 @@ func messageToSchemas(pge string, message parser.Message, swagger *openapi.Spec) name = pge + "." + option.Name + "_" + name swagger.Definitions[defName(name)] = parameter attr := &openapi.Schema{ - Description: getTitle(option.Doc), + Description: doc, Ref: "#/definitions/" + defName(name), // 嵌套肯定是本包 } properties[option.Name] = attr @@ -296,7 +308,7 @@ func messageToSchemas(pge string, message parser.Message, swagger *openapi.Spec) Type: "array", Items: &openapi.Schema{ Ref: getRef(pge, option.Ty), - Description: getTitle(option.Doc), + Description: doc, Type: "object", Format: option.Ty, }, @@ -305,7 +317,7 @@ func messageToSchemas(pge string, message parser.Message, swagger *openapi.Spec) } } else if isProtoBaseType(option.Ty) { attr := &openapi.Schema{ - Description: getTitle(option.Doc), + Description: doc, Type: getProtoToSwagger(option.Ty), Format: option.Ty, } @@ -315,13 +327,13 @@ func messageToSchemas(pge string, message parser.Message, swagger *openapi.Spec) name = pge + "." + option.Name + "_" + name swagger.Definitions[defName(name)] = parameter attr := &openapi.Schema{ - Description: getTitle(option.Doc), + Description: doc, Ref: "#/definitions/" + defName(name), // 嵌套肯定是本包 } properties[option.Name] = attr } else { attr := &openapi.Schema{ - Description: getTitle(option.Doc), + Description: doc, Ref: getRef(pge, option.Ty), } properties[option.Name] = attr @@ -330,6 +342,7 @@ func messageToSchemas(pge string, message parser.Message, swagger *openapi.Spec) schema.Type = "object" schema.Properties = properties + schema.Required = requireArr return pge + "." + message.Name, schema } @@ -438,3 +451,13 @@ func findMessage(message string, nowDirProtoc []parser.ProtocFileParser, allProt } return nil, "" } + +func filterRequired(doc string) (string, bool) { + re := regexp.MustCompile("@(tag|Tag|TAG)\\(\\\"([a-zA-Z]+)\"[,\\s\\\"]+([a-zA-Z]+)\"\\)") + arr := re.FindStringSubmatch(doc) + if len(arr) == 4 && strings.ToLower(arr[2]) == "binding" && strings.ToLower(arr[3]) == "required" { + doc = strings.Trim(re.ReplaceAllString(doc, ""), "\r\n") + return doc, true + } + return doc, false +}