feat: support custom table primary key type and name

This commit is contained in:
zhuyasen
2024-11-04 20:30:28 +08:00
parent 4deb203a02
commit 6eb9e51b9d
51 changed files with 5514 additions and 344 deletions

View File

@@ -0,0 +1,386 @@
package parser
import (
"encoding/json"
"fmt"
"strings"
"text/template"
"github.com/jinzhu/inflection"
)
// CrudInfo crud info for cache, dao, handler, service, protobuf, error
type CrudInfo struct {
TableNameCamel string `json:"tableNameCamel"` // camel case, example: FooBar
TableNameCamelFCL string `json:"tableNameCamelFCL"` // camel case and first character lower, example: fooBar
TableNamePluralCamel string `json:"tableNamePluralCamel"` // plural, camel case, example: FooBars
TableNamePluralCamelFCL string `json:"tableNamePluralCamelFCL"` // plural, camel case, example: fooBars
ColumnName string `json:"columnName"` // column name, example: first_name
ColumnNameCamel string `json:"columnNameCamel"` // column name, camel case, example: FirstName
ColumnNameCamelFCL string `json:"columnNameCamelFCL"` // column name, camel case and first character lower, example: firstName
ColumnNamePluralCamel string `json:"columnNamePluralCamel"` // column name, plural, camel case, example: FirstNames
ColumnNamePluralCamelFCL string `json:"columnNamePluralCamelFCL"` // column name, plural, camel case and first character lower, example: firstNames
GoType string `json:"goType"` // go type, example: string, uint64
GoTypeFCU string `json:"goTypeFCU"` // go type, first character upper, example: String, Uint64
ProtoType string `json:"protoType"` // proto type, example: string, uint64
IsStringType bool `json:"isStringType"` // go type is string or not
PrimaryKeyColumnName string `json:"PrimaryKeyColumnName"` // primary key, example: id
IsCommonType bool `json:"isCommonType"` // custom primary key name and type
IsStandardPrimaryKey bool `json:"isStandardPrimaryKey"` // standard primary key id
}
func isDesiredGoType(t string) bool {
switch t {
case "string", "uint64", "int64", "uint", "int", "uint32", "int32": //nolint
return true
}
return false
}
func setCrudInfo(field tmplField) *CrudInfo {
primaryKeyName := ""
if field.IsPrimaryKey {
primaryKeyName = field.ColName
}
pluralName := inflection.Plural(field.Name)
return &CrudInfo{
ColumnName: field.ColName,
ColumnNameCamel: field.Name,
ColumnNameCamelFCL: customFirstLetterToLower(field.Name),
ColumnNamePluralCamel: customEndOfLetterToLower(field.Name, pluralName),
ColumnNamePluralCamelFCL: customFirstLetterToLower(customEndOfLetterToLower(field.Name, pluralName)),
GoType: field.GoType,
GoTypeFCU: firstLetterToUpper(field.GoType),
ProtoType: simpleGoTypeToProtoType(field.GoType),
IsStringType: field.GoType == "string",
PrimaryKeyColumnName: primaryKeyName,
IsStandardPrimaryKey: field.ColName == "id",
}
}
func newCrudInfo(data tmplData) *CrudInfo {
var info *CrudInfo
for _, field := range data.Fields {
if field.IsPrimaryKey {
info = setCrudInfo(field)
break
}
}
// if not found primary key, find the first xxx_id column as primary key
if info == nil {
for _, field := range data.Fields {
if strings.HasSuffix(field.ColName, "_id") && isDesiredGoType(field.GoType) { // xxx_id
info = setCrudInfo(field)
break
}
}
}
// if not found xxx_id field, use the first column as primary key
if info == nil {
for _, field := range data.Fields {
if isDesiredGoType(field.GoType) {
info = setCrudInfo(field)
break
}
}
if len(data.Fields) > 0 {
info = setCrudInfo(data.Fields[0])
} else {
return nil
}
}
info.TableNameCamel = data.TableName
info.TableNameCamelFCL = data.TName
pluralName := inflection.Plural(data.TableName)
info.TableNamePluralCamel = customEndOfLetterToLower(data.TableName, pluralName)
info.TableNamePluralCamelFCL = customFirstLetterToLower(customEndOfLetterToLower(data.TableName, pluralName))
return info
}
func (info *CrudInfo) getCode() string {
if info == nil {
return ""
}
pkData, _ := json.Marshal(info)
return string(pkData)
}
func (info *CrudInfo) CheckCommonType() bool {
if info == nil {
return false
}
return info.IsCommonType
}
func (info *CrudInfo) isIDPrimaryKey() bool {
if info == nil {
return false
}
if info.ColumnName == "id" && (info.GoType == "uint64" ||
info.GoType == "int64" ||
info.GoType == "uint" ||
info.GoType == "int" ||
info.GoType == "uint32" ||
info.GoType == "int32") {
return true
}
return false
}
func (info *CrudInfo) GetGRPCProtoValidation() string {
if info == nil {
return ""
}
if info.ProtoType == "string" {
return `[(validate.rules).string.min_len = 1]`
}
return fmt.Sprintf(`[(validate.rules).%s.gt = 0]`, info.ProtoType)
}
func (info *CrudInfo) GetWebProtoValidation() string {
if info == nil {
return ""
}
if info.ProtoType == "string" {
return fmt.Sprintf(`[(validate.rules).string.min_len = 1, (tagger.tags) = "uri:\"%s\""]`, info.ColumnNameCamelFCL)
}
return fmt.Sprintf(`[(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"%s\""]`, info.ProtoType, info.ColumnNameCamelFCL)
}
func getCommonHandlerStructCodes(data tmplData, jsonNamedType int) (string, error) {
newFields := []tmplField{}
for _, field := range data.Fields {
if jsonNamedType == 0 { // snake case
field.JSONName = customToSnake(field.ColName)
} else {
field.JSONName = customToCamel(field.ColName) // camel case (default)
}
newFields = append(newFields, field)
}
data.Fields = newFields
postStructCode, err := tmplExecuteWithFilter(data, handlerCreateStructCommonTmpl)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
}
putStructCode, err := tmplExecuteWithFilter(data, handlerUpdateStructCommonTmpl, columnID)
if err != nil {
return "", fmt.Errorf("handlerUpdateStructTmpl error: %v", err)
}
getStructCode, err := tmplExecuteWithFilter(data, handlerDetailStructCommonTmpl, columnID, columnCreatedAt, columnUpdatedAt)
if err != nil {
return "", fmt.Errorf("handlerDetailStructTmpl error: %v", err)
}
return postStructCode + putStructCode + getStructCode, nil
}
func getCommonServiceStructCode(data tmplData) (string, error) {
builder := strings.Builder{}
err := serviceStructCommonTmpl.Execute(&builder, data)
if err != nil {
return "", err
}
code := builder.String()
serviceCreateStructCode, err := tmplExecuteWithFilter(data, serviceCreateStructCommonTmpl)
if err != nil {
return "", fmt.Errorf("handle serviceCreateStructTmpl error: %v", err)
}
serviceCreateStructCode = strings.ReplaceAll(serviceCreateStructCode, "ID:", "Id:")
serviceUpdateStructCode, err := tmplExecuteWithFilter(data, serviceUpdateStructCommonTmpl, columnID)
if err != nil {
return "", fmt.Errorf("handle serviceUpdateStructTmpl error: %v", err)
}
serviceUpdateStructCode = strings.ReplaceAll(serviceUpdateStructCode, "ID:", "Id:")
code = strings.ReplaceAll(code, "// serviceCreateStructCode", serviceCreateStructCode)
code = strings.ReplaceAll(code, "// serviceUpdateStructCode", serviceUpdateStructCode)
return code, nil
}
func getCommonProtoFileCode(data tmplData, jsonNamedType int, isWebProto bool, isExtendedAPI bool) (string, error) {
data.Fields = goTypeToProto(data.Fields, jsonNamedType, true)
var err error
builder := strings.Builder{}
if isWebProto {
if isExtendedAPI {
err = protoFileForWebCommonTmpl.Execute(&builder, data)
} else {
err = protoFileForSimpleWebCommonTmpl.Execute(&builder, data)
}
if err != nil {
return "", err
}
} else {
if isExtendedAPI {
err = protoFileCommonTmpl.Execute(&builder, data)
} else {
err = protoFileSimpleCommonTmpl.Execute(&builder, data)
}
if err != nil {
return "", err
}
}
code := builder.String()
protoMessageCreateCode, err := tmplExecuteWithFilter2(data, protoMessageCreateCommonTmpl)
if err != nil {
return "", fmt.Errorf("handle protoMessageCreateCommonTmpl error: %v", err)
}
protoMessageUpdateCode, err := tmplExecuteWithFilter2(data, protoMessageUpdateCommonTmpl, columnID)
if err != nil {
return "", fmt.Errorf("handle protoMessageUpdateCommonTmpl error: %v", err)
}
if !isWebProto {
srcStr := fmt.Sprintf(`, (tagger.tags) = "uri:\"%s\""`, getProtoFieldName(data.Fields))
protoMessageUpdateCode = strings.ReplaceAll(protoMessageUpdateCode, srcStr, "")
}
protoMessageDetailCode, err := tmplExecuteWithFilter2(data, protoMessageDetailCommonTmpl, columnID, columnCreatedAt, columnUpdatedAt)
if err != nil {
return "", fmt.Errorf("handle protoMessageDetailCommonTmpl error: %v", err)
}
code = strings.ReplaceAll(code, "// protoMessageCreateCode", protoMessageCreateCode)
code = strings.ReplaceAll(code, "// protoMessageUpdateCode", protoMessageUpdateCode)
code = strings.ReplaceAll(code, "// protoMessageDetailCode", protoMessageDetailCode)
code = strings.ReplaceAll(code, "*time.Time", "int64")
code = strings.ReplaceAll(code, "time.Time", "int64")
code = strings.ReplaceAll(code, "left_curly_bracket", "{")
code = strings.ReplaceAll(code, "right_curly_bracket", "}")
code = adaptedDbType2(data, isWebProto, code)
return code, nil
}
func tmplExecuteWithFilter2(data tmplData, tmpl *template.Template, reservedColumns ...string) (string, error) {
var newFields = []tmplField{}
for _, field := range data.Fields {
if isIgnoreFields(field.ColName, reservedColumns...) {
continue
}
newFields = append(newFields, field)
}
data.Fields = newFields
builder := strings.Builder{}
err := tmpl.Execute(&builder, data)
if err != nil {
return "", fmt.Errorf("tmpl.Execute error: %v", err)
}
return builder.String(), nil
}
// nolint
func simpleGoTypeToProtoType(goType string) string {
var protoType string
switch goType {
case "int", "int32":
protoType = "int32"
case "uint", "uint32":
protoType = "uint32"
case "int64":
protoType = "int64"
case "uint64":
protoType = "uint64"
case "string":
protoType = "string"
case "time.Time", "*time.Time":
protoType = "string"
case "float32":
protoType = "float"
case "float64":
protoType = "double"
case goTypeInts, "[]int64":
protoType = "repeated int64"
case "[]int32":
protoType = "repeated int32"
case "[]byte":
protoType = "string"
case goTypeStrings:
protoType = "repeated string"
case jsonTypeName:
protoType = "string"
default:
protoType = "string"
}
return protoType
}
func adaptedDbType2(data tmplData, isWebProto bool, code string) string {
if isWebProto {
code = replaceProtoMessageFieldCode(code, webDefaultProtoMessageFieldCodes)
} else {
code = replaceProtoMessageFieldCode(code, grpcDefaultProtoMessageFieldCodes)
}
if data.ProtoSubStructs != "" {
code += "\n" + data.ProtoSubStructs
}
return code
}
func firstLetterToUpper(str string) string {
if len(str) == 0 {
return str
}
if (str[0] >= 'A' && str[0] <= 'Z') || (str[0] >= 'a' && str[0] <= 'z') {
return strings.ToUpper(str[:1]) + str[1:]
}
return str
}
func customFirstLetterToLower(str string) string {
str = firstLetterToLower(str)
if len(str) == 2 {
if str == "iD" {
str = "id"
} else if str == "iP" {
str = "ip"
}
} else if len(str) == 3 {
if str == "iDs" {
str = "ids"
} else if str == "iPs" {
str = "ips"
}
}
return str
}
func customEndOfLetterToLower(srcStr string, str string) string {
l := len(str) - len(srcStr)
if l == 1 {
if str[len(str)-1] == 'S' {
return str[:len(str)-1] + "s"
}
} else if l == 2 {
if str[len(str)-2:] == "ES" {
return str[:len(str)-2] + "es"
}
}
return str
}

View File

@@ -0,0 +1,740 @@
package parser
import (
"sync"
"text/template"
"github.com/pkg/errors"
)
// nolint
var (
handlerCreateStructCommonTmpl *template.Template
handlerCreateStructCommonTmplRaw = `
// Create{{.TableName}}Request request params
type Create{{.TableName}}Request struct {
{{- range .Fields}}
{{.Name}} {{.GoType}} ` + "`" + `json:"{{.JSONName}}" binding:""` + "`" + `{{if .Comment}} // {{.Comment}}{{end}}
{{- end}}
}
`
handlerUpdateStructCommonTmpl *template.Template
handlerUpdateStructCommonTmplRaw = `
// Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request request params
type Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request struct {
{{- range .Fields}}
{{.Name}} {{.GoType}} ` + "`" + `json:"{{.JSONName}}" binding:""` + "`" + `{{if .Comment}} // {{.Comment}}{{end}}
{{- end}}
}
`
handlerDetailStructCommonTmpl *template.Template
handlerDetailStructCommonTmplRaw = `
// {{.TableName}}ObjDetail detail
type {{.TableName}}ObjDetail struct {
{{- range .Fields}}
{{.Name}} {{.GoType}} ` + "`" + `json:"{{.JSONName}}"` + "`" + `{{if .Comment}} // {{.Comment}}{{end}}
{{- end}}
}`
protoFileCommonTmpl *template.Template
protoFileCommonTmplRaw = `syntax = "proto3";
package api.serverNameExample.v1;
import "api/types/types.proto";
import "validate/validate.proto";
option go_package = "github.com/zhufuyi/sponge/api/serverNameExample/v1;v1";
service {{.TName}} {
// create {{.TName}}
rpc Create(Create{{.TableName}}Request) returns (Create{{.TableName}}Reply) {}
// delete {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc DeleteBy{{.CrudInfo.ColumnNameCamel}}(Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {}
// update {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc UpdateBy{{.CrudInfo.ColumnNameCamel}}(Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {}
// get {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc GetBy{{.CrudInfo.ColumnNameCamel}}(Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {}
// list of {{.TName}} by query parameters
rpc List(List{{.TableName}}Request) returns (List{{.TableName}}Reply) {}
// delete {{.TName}} by batch {{.CrudInfo.ColumnNameCamelFCL}}
rpc DeleteBy{{.CrudInfo.ColumnNamePluralCamel}}(Delete{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Request) returns (Delete{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Reply) {}
// get {{.TName}} by condition
rpc GetByCondition(Get{{.TableName}}ByConditionRequest) returns (Get{{.TableName}}ByConditionReply) {}
// list of {{.TName}} by batch {{.CrudInfo.ColumnNameCamelFCL}}
rpc ListBy{{.CrudInfo.ColumnNamePluralCamel}}(List{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Request) returns (List{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Reply) {}
// list {{.TName}} by last {{.CrudInfo.ColumnNameCamelFCL}}
rpc ListByLast{{.CrudInfo.ColumnNameCamel}}(List{{.TableName}}ByLast{{.CrudInfo.ColumnNameCamel}}Request) returns (List{{.TableName}}ByLast{{.CrudInfo.ColumnNameCamel}}Reply) {}
}
/*
Notes for defining message fields:
1. Suggest using camel case style naming for message field names, such as firstName, lastName, etc.
2. If the message field name ending in 'id', it is recommended to use xxxID naming format, such as userID, orderID, etc.
3. Add validate rules https://github.com/envoyproxy/protoc-gen-validate#constraint-rules, such as:
uint64 id = 1 [(validate.rules).uint64.gte = 1];
*/
// protoMessageCreateCode
message Create{{.TableName}}Reply {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1;
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1 {{.CrudInfo.GetGRPCProtoValidation}};
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
}
// protoMessageUpdateCode
message Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
}
// protoMessageDetailCode
message Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1 {{.CrudInfo.GetGRPCProtoValidation}};
}
message Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
{{.TableName}} {{.TName}} = 1;
}
message List{{.TableName}}Request {
api.types.Params params = 1;
}
message List{{.TableName}}Reply {
int64 total = 1;
repeated {{.TableName}} {{.CrudInfo.TableNamePluralCamelFCL}} = 2;
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Request {
repeated {{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNamePluralCamelFCL}} = 1 [(validate.rules).repeated.min_items = 1];
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Reply {
}
message Get{{.TableName}}ByConditionRequest {
types.Conditions conditions = 1;
}
message Get{{.TableName}}ByConditionReply {
{{.TableName}} {{.TName}} = 1;
}
message List{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Request {
repeated {{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNamePluralCamelFCL}} = 1 [(validate.rules).repeated.min_items = 1];
}
message List{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Reply {
repeated {{.TableName}} {{.CrudInfo.TableNamePluralCamelFCL}} = 1;
}
message List{{.TableName}}ByLast{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} last{{.CrudInfo.ColumnNameCamel}} = 1;
uint32 limit = 2 [(validate.rules).uint32.gt = 0]; // limit size per page
string sort = 3; // sort by column name of table, default is -{{.CrudInfo.ColumnName}}, the - sign indicates descending order.
}
message List{{.TableName}}ByLast{{.CrudInfo.ColumnNameCamel}}Reply {
repeated {{.TableName}} {{.CrudInfo.TableNamePluralCamelFCL}} = 1;
}
`
protoFileSimpleCommonTmpl *template.Template
protoFileSimpleCommonTmplRaw = `syntax = "proto3";
package api.serverNameExample.v1;
import "api/types/types.proto";
import "validate/validate.proto";
option go_package = "github.com/zhufuyi/sponge/api/serverNameExample/v1;v1";
service {{.TName}} {
// create {{.TName}}
rpc Create(Create{{.TableName}}Request) returns (Create{{.TableName}}Reply) {}
// delete {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc DeleteBy{{.CrudInfo.ColumnNameCamel}}(Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {}
// update {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc UpdateBy{{.CrudInfo.ColumnNameCamel}}(Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {}
// get {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc GetBy{{.CrudInfo.ColumnNameCamel}}(Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {}
// list of {{.TName}} by query parameters
rpc List(List{{.TableName}}Request) returns (List{{.TableName}}Reply) {}
}
/*
Notes for defining message fields:
1. Suggest using camel case style naming for message field names, such as firstName, lastName, etc.
2. If the message field name ending in 'id', it is recommended to use xxxID naming format, such as userID, orderID, etc.
3. Add validate rules https://github.com/envoyproxy/protoc-gen-validate#constraint-rules, such as:
uint64 id = 1 [(validate.rules).uint64.gte = 1];
*/
// protoMessageCreateCode
message Create{{.TableName}}Reply {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1;
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1 {{.CrudInfo.GetGRPCProtoValidation}};
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
}
// protoMessageUpdateCode
message Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
}
// protoMessageDetailCode
message Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1 {{.CrudInfo.GetGRPCProtoValidation}};
}
message Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
{{.TableName}} {{.TName}} = 1;
}
message List{{.TableName}}Request {
api.types.Params params = 1;
}
message List{{.TableName}}Reply {
int64 total = 1;
repeated {{.TableName}} {{.CrudInfo.TableNamePluralCamelFCL}} = 2;
}
`
protoFileForWebCommonTmpl *template.Template
protoFileForWebCommonTmplRaw = `syntax = "proto3";
package api.serverNameExample.v1;
import "api/types/types.proto";
import "google/api/annotations.proto";
import "protoc-gen-openapiv2/options/annotations.proto";
import "tagger/tagger.proto";
import "validate/validate.proto";
option go_package = "github.com/zhufuyi/sponge/api/serverNameExample/v1;v1";
/*
Reference https://github.com/grpc-ecosystem/grpc-gateway/blob/db7fbefff7c04877cdb32e16d4a248a024428207/examples/internal/proto/examplepb/a_bit_of_everything.proto
Default settings for generating swagger documents
NOTE: because json does not support 64 bits, the int64 and uint64 types under *.swagger.json are automatically converted to string types
Tips: add swagger option to rpc method, example:
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "get user by id",
description: "get user by id",
security: {
security_requirement: {
key: "BearerAuth";
value: {}
}
}
};
*/
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
host: "localhost:8080"
base_path: ""
info: {
title: "serverNameExample api docs";
version: "2.0";
}
schemes: HTTP;
schemes: HTTPS;
consumes: "application/json";
produces: "application/json";
security_definitions: {
security: {
key: "BearerAuth";
value: {
type: TYPE_API_KEY;
in: IN_HEADER;
name: "Authorization";
description: "Type Bearer your-jwt-token to Value";
}
}
}
};
service {{.TName}} {
// create {{.TName}}
rpc Create(Create{{.TableName}}Request) returns (Create{{.TableName}}Reply) {
option (google.api.http) = {
post: "/api/v1/{{.TName}}"
body: "*"
};
}
// delete {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc DeleteBy{{.CrudInfo.ColumnNameCamel}}(Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {
option (google.api.http) = {
delete: "/api/v1/{{.TName}}/left_curly_bracket{{.CrudInfo.ColumnNameCamelFCL}}right_curly_bracket"
};
}
// update {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc UpdateBy{{.CrudInfo.ColumnNameCamel}}(Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {
option (google.api.http) = {
put: "/api/v1/{{.TName}}/left_curly_bracket{{.CrudInfo.ColumnNameCamelFCL}}right_curly_bracket"
body: "*"
};
}
// get {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc GetBy{{.CrudInfo.ColumnNameCamel}}(Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {
option (google.api.http) = {
get: "/api/v1/{{.TName}}/left_curly_bracket{{.CrudInfo.ColumnNameCamelFCL}}right_curly_bracket"
};
}
// list of {{.TName}} by query parameters
rpc List(List{{.TableName}}Request) returns (List{{.TableName}}Reply) {
option (google.api.http) = {
post: "/api/v1/{{.TName}}/list"
body: "*"
};
}
// delete {{.TName}} by batch {{.CrudInfo.ColumnNameCamelFCL}}
rpc DeleteBy{{.CrudInfo.ColumnNamePluralCamel}}(Delete{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Request) returns (Delete{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Reply) {
option (google.api.http) = {
post: "/api/v1/{{.TName}}/delete/ids"
body: "*"
};
}
// get {{.TName}} by condition
rpc GetByCondition(Get{{.TableName}}ByConditionRequest) returns (Get{{.TableName}}ByConditionReply) {
option (google.api.http) = {
post: "/api/v1/{{.TName}}/condition"
body: "*"
};
}
// list of {{.TName}} by batch {{.CrudInfo.ColumnNameCamelFCL}}
rpc ListBy{{.CrudInfo.ColumnNamePluralCamel}}(List{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Request) returns (List{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Reply) {
option (google.api.http) = {
post: "/api/v1/{{.TName}}/list/ids"
body: "*"
};
}
// list {{.TName}} by last {{.CrudInfo.ColumnNameCamelFCL}}
rpc ListByLast{{.CrudInfo.ColumnNameCamel}}(List{{.TableName}}ByLast{{.CrudInfo.ColumnNameCamel}}Request) returns (List{{.TableName}}ByLast{{.CrudInfo.ColumnNameCamel}}Reply) {
option (google.api.http) = {
get: "/api/v1/{{.TName}}/list"
};
}
}
/*
Notes for defining message fields:
1. Suggest using camel case style naming for message field names, such as firstName, lastName, etc.
2. If the message field name ending in 'id', it is recommended to use xxxID naming format, such as userID, orderID, etc.
3. Add validate rules https://github.com/envoyproxy/protoc-gen-validate#constraint-rules, such as:
uint64 id = 1 [(validate.rules).uint64.gte = 1];
If used to generate code that supports the HTTP protocol, notes for defining message fields:
1. If the route contains the path parameter, such as /api/v1/userExample/{id}, the defined
message must contain the name of the path parameter and the name should be added
with a new tag, such as int64 id = 1 [(tagger.tags) = "uri:\"id\""];
2. If the request url is followed by a query parameter, such as /api/v1/getUserExample?name=Tom,
a form tag must be added when defining the query parameter in the message, such as:
string name = 1 [(tagger.tags) = "form:\"name\""].
3. If the message field name contain underscores(such as 'field_name'), it will cause a problem
where the JSON field names of the Swagger request parameters are different from those of the
GRPC JSON tag names. There are two solutions: Solution 1, remove the underline from the
message field name. Option 2, use the tool 'protoc-go-inject-tag' to modify the JSON tag name,
such as: string first_name = 1 ; // @gotags: json:"firstName"
*/
// protoMessageCreateCode
message Create{{.TableName}}Reply {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1;
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1 {{.CrudInfo.GetWebProtoValidation}};
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
}
// protoMessageUpdateCode
message Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
}
// protoMessageDetailCode
message Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1 {{.CrudInfo.GetWebProtoValidation}};
}
message Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
{{.TableName}} {{.TName}} = 1;
}
message List{{.TableName}}Request {
api.types.Params params = 1;
}
message List{{.TableName}}Reply {
int64 total = 1;
repeated {{.TableName}} {{.CrudInfo.TableNamePluralCamelFCL}} = 2;
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Request {
repeated {{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNamePluralCamelFCL}} = 1 [(validate.rules).repeated.min_items = 1];
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Reply {
}
message Get{{.TableName}}ByConditionRequest {
types.Conditions conditions = 1;
}
message Get{{.TableName}}ByConditionReply {
{{.TableName}} {{.TName}} = 1;
}
message List{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Request {
repeated {{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNamePluralCamelFCL}} = 1 [(validate.rules).repeated.min_items = 1];
}
message List{{.TableName}}By{{.CrudInfo.ColumnNamePluralCamel}}Reply {
repeated {{.TableName}} {{.CrudInfo.TableNamePluralCamelFCL}} = 1;
}
message List{{.TableName}}ByLast{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} last{{.CrudInfo.ColumnNameCamel}} = 1 [(tagger.tags) = "form:\"last{{.CrudInfo.ColumnNameCamel}}\""];
uint32 limit = 2 [(validate.rules).uint32.gt = 0, (tagger.tags) = "form:\"limit\""]; // limit size per page
string sort = 3 [(tagger.tags) = "form:\"sort\""]; // sort by column name of table, default is -{{.CrudInfo.ColumnName}}, the - sign indicates descending order.
}
message List{{.TableName}}ByLast{{.CrudInfo.ColumnNameCamel}}Reply {
repeated {{.TableName}} {{.CrudInfo.TableNamePluralCamelFCL}} = 1;
}
`
protoFileForSimpleWebCommonTmpl *template.Template
protoFileForSimpleWebCommonTmplRaw = `syntax = "proto3";
package api.serverNameExample.v1;
import "api/types/types.proto";
import "google/api/annotations.proto";
import "protoc-gen-openapiv2/options/annotations.proto";
import "tagger/tagger.proto";
import "validate/validate.proto";
option go_package = "github.com/zhufuyi/sponge/api/serverNameExample/v1;v1";
/*
Reference https://github.com/grpc-ecosystem/grpc-gateway/blob/db7fbefff7c04877cdb32e16d4a248a024428207/examples/internal/proto/examplepb/a_bit_of_everything.proto
Default settings for generating swagger documents
NOTE: because json does not support 64 bits, the int64 and uint64 types under *.swagger.json are automatically converted to string types
Tips: add swagger option to rpc method, example:
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "get user by id",
description: "get user by id",
security: {
security_requirement: {
key: "BearerAuth";
value: {}
}
}
};
*/
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
host: "localhost:8080"
base_path: ""
info: {
title: "serverNameExample api docs";
version: "2.0";
}
schemes: HTTP;
schemes: HTTPS;
consumes: "application/json";
produces: "application/json";
security_definitions: {
security: {
key: "BearerAuth";
value: {
type: TYPE_API_KEY;
in: IN_HEADER;
name: "Authorization";
description: "Type Bearer your-jwt-token to Value";
}
}
}
};
service {{.TName}} {
// create {{.TName}}
rpc Create(Create{{.TableName}}Request) returns (Create{{.TableName}}Reply) {
option (google.api.http) = {
post: "/api/v1/{{.TName}}"
body: "*"
};
}
// delete {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc DeleteBy{{.CrudInfo.ColumnNameCamel}}(Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {
option (google.api.http) = {
delete: "/api/v1/{{.TName}}/left_curly_bracket{{.CrudInfo.ColumnNameCamelFCL}}right_curly_bracket"
};
}
// update {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc UpdateBy{{.CrudInfo.ColumnNameCamel}}(Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {
option (google.api.http) = {
put: "/api/v1/{{.TName}}/left_curly_bracket{{.CrudInfo.ColumnNameCamelFCL}}right_curly_bracket"
body: "*"
};
}
// get {{.TName}} by {{.CrudInfo.ColumnNameCamelFCL}}
rpc GetBy{{.CrudInfo.ColumnNameCamel}}(Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request) returns (Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply) {
option (google.api.http) = {
get: "/api/v1/{{.TName}}/left_curly_bracket{{.CrudInfo.ColumnNameCamelFCL}}right_curly_bracket"
};
}
// list of {{.TName}} by query parameters
rpc List(List{{.TableName}}Request) returns (List{{.TableName}}Reply) {
option (google.api.http) = {
post: "/api/v1/{{.TName}}/list"
body: "*"
};
}
}
/*
Notes for defining message fields:
1. Suggest using camel case style naming for message field names, such as firstName, lastName, etc.
2. If the message field name ending in 'id', it is recommended to use xxxID naming format, such as userID, orderID, etc.
3. Add validate rules https://github.com/envoyproxy/protoc-gen-validate#constraint-rules, such as:
uint64 id = 1 [(validate.rules).uint64.gte = 1];
If used to generate code that supports the HTTP protocol, notes for defining message fields:
1. If the route contains the path parameter, such as /api/v1/userExample/{id}, the defined
message must contain the name of the path parameter and the name should be added
with a new tag, such as int64 id = 1 [(tagger.tags) = "uri:\"id\""];
2. If the request url is followed by a query parameter, such as /api/v1/getUserExample?name=Tom,
a form tag must be added when defining the query parameter in the message, such as:
string name = 1 [(tagger.tags) = "form:\"name\""].
3. If the message field name contain underscores(such as 'field_name'), it will cause a problem
where the JSON field names of the Swagger request parameters are different from those of the
GRPC JSON tag names. There are two solutions: Solution 1, remove the underline from the
message field name. Option 2, use the tool 'protoc-go-inject-tag' to modify the JSON tag name,
such as: string first_name = 1 ; // @gotags: json:"firstName"
*/
// protoMessageCreateCode
message Create{{.TableName}}Reply {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1;
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1 {{.CrudInfo.GetWebProtoValidation}};
}
message Delete{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
}
// protoMessageUpdateCode
message Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
}
// protoMessageDetailCode
message Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{.CrudInfo.ProtoType}} {{.CrudInfo.ColumnNameCamelFCL}} = 1 {{.CrudInfo.GetWebProtoValidation}};
}
message Get{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Reply {
{{.TableName}} {{.TName}} = 1;
}
message List{{.TableName}}Request {
api.types.Params params = 1;
}
message List{{.TableName}}Reply {
int64 total = 1;
repeated {{.TableName}} {{.CrudInfo.TableNamePluralCamelFCL}} = 2;
}
`
protoMessageCreateCommonTmpl *template.Template
protoMessageCreateCommonTmplRaw = `message Create{{.TableName}}Request {
{{- range $i, $v := .Fields}}
{{$v.GoType}} {{$v.JSONName}} = {{$v.AddOne $i}}; {{if $v.Comment}} // {{$v.Comment}}{{end}}
{{- end}}
}`
protoMessageUpdateCommonTmpl *template.Template
protoMessageUpdateCommonTmplRaw = `message Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request {
{{- range $i, $v := .Fields}}
{{$v.GoType}} {{$v.JSONName}} = {{$v.AddOneWithTag2 $i}}; {{if $v.Comment}} // {{$v.Comment}}{{end}}
{{- end}}
}`
protoMessageDetailCommonTmpl *template.Template
protoMessageDetailCommonTmplRaw = `message {{.TableName}} {
{{- range $i, $v := .Fields}}
{{$v.GoType}} {{$v.JSONName}} = {{$v.AddOne $i}}; {{if $v.Comment}} // {{$v.Comment}}{{end}}
{{- end}}
}`
serviceStructCommonTmpl *template.Template
serviceStructCommonTmplRaw = `
{
name: "Create",
fn: func() (interface{}, error) {
// todo enter parameters before testing
// serviceCreateStructCode
},
wantErr: false,
},
{
name: "UpdateBy{{.CrudInfo.ColumnNameCamel}}",
fn: func() (interface{}, error) {
// todo enter parameters before testing
// serviceUpdateStructCode
},
wantErr: false,
},
`
serviceCreateStructCommonTmpl *template.Template
serviceCreateStructCommonTmplRaw = ` req := &serverNameExampleV1.Create{{.TableName}}Request{
{{- range .Fields}}
{{.Name}}: {{.GoTypeZero}}, {{if .Comment}} // {{.Comment}}{{end}}
{{- end}}
}
return cli.Create(ctx, req)`
serviceUpdateStructCommonTmpl *template.Template
serviceUpdateStructCommonTmplRaw = ` req := &serverNameExampleV1.Update{{.TableName}}By{{.CrudInfo.ColumnNameCamel}}Request{
{{- range .Fields}}
{{.Name}}: {{.GoTypeZero}}, {{if .Comment}} // {{.Comment}}{{end}}
{{- end}}
}
return cli.UpdateBy{{.CrudInfo.ColumnNameCamel}}(ctx, req)`
commonTmplParseOnce sync.Once
)
func initCommonTemplate() {
commonTmplParseOnce.Do(func() {
var err, errSum error
handlerCreateStructCommonTmpl, err = template.New("goPostStruct").Parse(handlerCreateStructCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "handlerCreateStructCommonTmplRaw:"+err.Error())
}
handlerUpdateStructCommonTmpl, err = template.New("goPutStruct").Parse(handlerUpdateStructCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "handlerUpdateStructCommonTmplRaw:"+err.Error())
}
handlerDetailStructCommonTmpl, err = template.New("goGetStruct").Parse(handlerDetailStructCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "handlerDetailStructCommonTmplRaw:"+err.Error())
}
protoFileCommonTmpl, err = template.New("protoFile").Parse(protoFileCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "protoFileCommonTmplRaw:"+err.Error())
}
protoFileSimpleCommonTmpl, err = template.New("protoFileSimple").Parse(protoFileSimpleCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "protoFileSimpleCommonTmplRaw:"+err.Error())
}
protoFileForWebCommonTmpl, err = template.New("protoFileForWeb").Parse(protoFileForWebCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "protoFileForWebCommonTmplRaw:"+err.Error())
}
protoFileForSimpleWebCommonTmpl, err = template.New("protoFileForSimpleWeb").Parse(protoFileForSimpleWebCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "protoFileForSimpleWebCommonTmplRaw:"+err.Error())
}
protoMessageCreateCommonTmpl, err = template.New("protoMessageCreate").Parse(protoMessageCreateCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "protoMessageCreateCommonTmplRaw:"+err.Error())
}
protoMessageUpdateCommonTmpl, err = template.New("protoMessageUpdate").Parse(protoMessageUpdateCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "protoMessageUpdateCommonTmplRaw:"+err.Error())
}
protoMessageDetailCommonTmpl, err = template.New("protoMessageDetail").Parse(protoMessageDetailCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "protoMessageDetailCommonTmplRaw:"+err.Error())
}
serviceCreateStructCommonTmpl, err = template.New("serviceCreateStruct").Parse(serviceCreateStructCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "serviceCreateStructCommonTmplRaw:"+err.Error())
}
serviceUpdateStructCommonTmpl, err = template.New("serviceUpdateStruct").Parse(serviceUpdateStructCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "serviceUpdateStructCommonTmplRaw:"+err.Error())
}
serviceStructCommonTmpl, err = template.New("serviceStruct").Parse(serviceStructCommonTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "serviceStructCommonTmplRaw:"+err.Error())
}
if errSum != nil {
panic(errSum)
}
})
}

View File

@@ -5,18 +5,19 @@ package parser
import (
"bufio"
"bytes"
"errors"
"fmt"
"go/format"
"sort"
"strings"
"text/template"
"github.com/blastrain/vitess-sqlparser/tidbparser/ast"
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/mysql"
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/types"
"github.com/blastrain/vitess-sqlparser/tidbparser/parser"
"github.com/huandu/xstrings"
"github.com/jinzhu/inflection"
"github.com/zhufuyi/sqlparser/ast"
"github.com/zhufuyi/sqlparser/dependency/mysql"
"github.com/zhufuyi/sqlparser/dependency/types"
"github.com/zhufuyi/sqlparser/parser"
)
const (
@@ -34,6 +35,8 @@ const (
CodeTypeProto = "proto"
// CodeTypeService grpc service code
CodeTypeService = "service"
// CodeTypeCrudInfo crud info json data
CodeTypeCrudInfo = "crud_info"
// DBDriverMysql mysql driver
DBDriverMysql = "mysql"
@@ -68,6 +71,7 @@ type modelCodes struct {
// ParseSQL generate different usage codes based on sql
func ParseSQL(sql string, options ...Option) (map[string]string, error) {
initTemplate()
initCommonTemplate()
opt := parseOption(options)
stmts, err := parser.New().Parse(sql, opt.Charset, opt.Collation)
@@ -82,6 +86,7 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
modelJSONCodes := make([]string, 0, len(stmts))
importPath := make(map[string]struct{})
tableNames := make([]string, 0, len(stmts))
primaryKeysCodes := make([]string, 0, len(stmts))
for _, stmt := range stmts {
if ct, ok := stmt.(*ast.CreateTableStmt); ok {
code, err2 := makeCode(ct, opt)
@@ -95,6 +100,7 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
serviceStructCodes = append(serviceStructCodes, code.serviceStruct)
modelJSONCodes = append(modelJSONCodes, code.modelJSON)
tableNames = append(tableNames, toCamel(ct.Table.Name.String()))
primaryKeysCodes = append(primaryKeysCodes, code.crudInfo)
for _, s := range code.importPaths {
importPath[s] = struct{}{}
}
@@ -118,13 +124,14 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
}
var codesMap = map[string]string{
CodeTypeModel: modelCode,
CodeTypeJSON: strings.Join(modelJSONCodes, "\n\n"),
CodeTypeDAO: strings.Join(updateFieldsCodes, "\n\n"),
CodeTypeHandler: strings.Join(handlerStructCodes, "\n\n"),
CodeTypeProto: strings.Join(protoFileCodes, "\n\n"),
CodeTypeService: strings.Join(serviceStructCodes, "\n\n"),
TableName: strings.Join(tableNames, ", "),
CodeTypeModel: modelCode,
CodeTypeJSON: strings.Join(modelJSONCodes, "\n\n"),
CodeTypeDAO: strings.Join(updateFieldsCodes, "\n\n"),
CodeTypeHandler: strings.Join(handlerStructCodes, "\n\n"),
CodeTypeProto: strings.Join(protoFileCodes, "\n\n"),
CodeTypeService: strings.Join(serviceStructCodes, "\n\n"),
TableName: strings.Join(tableNames, ", "),
CodeTypeCrudInfo: strings.Join(primaryKeysCodes, "||||"),
}
return codesMap, nil
@@ -140,16 +147,19 @@ type tmplData struct {
SubStructs string // sub structs for model
ProtoSubStructs string // sub structs for protobuf
DBDriver string
CrudInfo *CrudInfo
}
type tmplField struct {
Name string
ColName string
GoType string
Tag string
Comment string
JSONName string
DBDriver string
IsPrimaryKey bool // is primary key
ColName string // table column name
Name string // convert to camel case
GoType string // convert to go type
Tag string
Comment string
JSONName string
DBDriver string
rewriterField *rewriterField
}
@@ -159,6 +169,13 @@ type rewriterField struct {
path string
}
func (d tmplData) isCommonStyle(isEmbed bool) bool {
if d.DBDriver != DBDriverMongodb && !isEmbed && !d.CrudInfo.isIDPrimaryKey() {
return true
}
return false
}
// ConditionZero type of condition 0, used in dao template code
func (t tmplField) ConditionZero() string {
switch t.GoType {
@@ -260,12 +277,33 @@ func (t tmplField) AddOne(i int) int {
// AddOneWithTag counter and add id tag
func (t tmplField) AddOneWithTag(i int) string {
if t.ColName == "id" {
return fmt.Sprintf(`%d [(tagger.tags) = "uri:\"id\"" ]`, i+1)
if t.DBDriver == DBDriverMongodb {
return fmt.Sprintf(`%d [(validate.rules).string.min_len = 6, (tagger.tags) = "uri:\"id\""]`, i+1)
}
return fmt.Sprintf(`%d [(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"id\""]`, i+1, t.GoType)
}
return fmt.Sprintf("%d", i+1)
}
func (t tmplField) AddOneWithTag2(i int) string {
if t.IsPrimaryKey || t.ColName == "id" {
if t.GoType == "string" {
return fmt.Sprintf(`%d [(validate.rules).string.min_len = 1, (tagger.tags) = "uri:\"%s\""]`, i+1, t.JSONName)
}
return fmt.Sprintf(`%d [(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"%s\""]`, i+1, t.GoType, t.JSONName)
}
return fmt.Sprintf("%d", i+1)
}
func getProtoFieldName(fields []tmplField) string {
for _, field := range fields {
if field.IsPrimaryKey || field.ColName == "id" {
return field.JSONName
}
}
return ""
}
const (
__mysqlModel__ = "__mysqlModel__" //nolint
__type__ = "__type__" //nolint
@@ -312,6 +350,7 @@ type codeText struct {
handlerStruct string
protoFile string
serviceStruct string
crudInfo string
}
// nolint
@@ -320,7 +359,8 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
data := tmplData{
TableName: stmt.Table.Name.String(),
RawTableName: stmt.Table.Name.String(),
Fields: make([]tmplField, 0, 1),
//Fields: make([]tmplField, 0, 1),
DBDriver: opt.DBDriver,
}
tablePrefix := opt.TablePrefix
if tablePrefix != "" && strings.HasPrefix(data.TableName, tablePrefix) {
@@ -340,7 +380,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
}
data.TableName = toCamel(data.TableName)
data.TName = firstLetterToLow(data.TableName)
data.TName = firstLetterToLower(data.TableName)
// find table comment
for _, o := range stmt.Options {
@@ -391,6 +431,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
}
}
if isPrimaryKey[colName] {
field.IsPrimaryKey = true
gormTag.WriteString(";primary_key")
}
isNotNull := false
@@ -444,7 +485,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
default: // gorm
if !isPrimaryKey[colName] && isNotNull {
gormTag.WriteString(";NOT NULL")
gormTag.WriteString(";not null")
}
tags = append(tags, "gorm", gormTag.String())
@@ -473,24 +514,26 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
data.Fields = append(data.Fields, field)
}
if len(data.Fields) == 0 {
return nil, errors.New("no columns found in table " + data.TableName)
}
data.CrudInfo = newCrudInfo(data)
data.CrudInfo.IsCommonType = data.isCommonStyle(opt.IsEmbed)
if v, ok := opt.FieldTypes[SubStructKey]; ok {
data.SubStructs = v
}
if v, ok := opt.FieldTypes[ProtoSubStructKey]; ok {
data.ProtoSubStructs = v
}
data.DBDriver = opt.DBDriver
updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
if err != nil {
return nil, err
}
handlerStructCode, err := getHandlerStructCodes(data, opt.JSONNamedType)
if err != nil {
return nil, err
}
modelStructCode, importPaths, err := getModelStructCode(data, importPath, opt.IsEmbed, opt.JSONNamedType)
if err != nil {
return nil, err
@@ -501,14 +544,35 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
return nil, err
}
protoFileCode, err := getProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
if err != nil {
return nil, err
}
serviceStructCode, err := getServiceStructCode(data)
if err != nil {
return nil, err
handlerStructCode := ""
serviceStructCode := ""
protoFileCode := ""
if data.isCommonStyle(opt.IsEmbed) {
handlerStructCode, err = getCommonHandlerStructCodes(data, opt.JSONNamedType)
if err != nil {
return nil, err
}
serviceStructCode, err = getCommonServiceStructCode(data)
if err != nil {
return nil, err
}
protoFileCode, err = getCommonProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
if err != nil {
return nil, err
}
} else {
handlerStructCode, err = getHandlerStructCodes(data, opt.JSONNamedType)
if err != nil {
return nil, err
}
serviceStructCode, err = getServiceStructCode(data)
if err != nil {
return nil, err
}
protoFileCode, err = getProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
if err != nil {
return nil, err
}
}
return &codeText{
@@ -519,6 +583,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
handlerStruct: handlerStructCode,
protoFile: protoFileCode,
serviceStruct: serviceStructCode,
crudInfo: data.CrudInfo.getCode(),
}, nil
}
@@ -586,6 +651,9 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
// force conversion of ID field to uint64 type
if field.Name == "ID" {
data.Fields[i].GoType = "uint64"
if data.isCommonStyle(isEmbed) {
data.Fields[i].GoType = data.CrudInfo.GoType
}
}
switch field.DBDriver {
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
@@ -763,7 +831,7 @@ func getModelJSONCode(data tmplData) (string, error) {
}
func getProtoFileCode(data tmplData, jsonNamedType int, isWebProto bool, isExtendedAPI bool) (string, error) {
data.Fields = goTypeToProto(data.Fields, jsonNamedType)
data.Fields = goTypeToProto(data.Fields, jsonNamedType, false)
var err error
builder := strings.Builder{}
@@ -790,20 +858,20 @@ func getProtoFileCode(data tmplData, jsonNamedType int, isWebProto bool, isExten
protoMessageCreateCode, err := tmplExecuteWithFilter(data, protoMessageCreateTmpl)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle protoMessageCreateTmpl error: %v", err)
}
protoMessageUpdateCode, err := tmplExecuteWithFilter(data, protoMessageUpdateTmpl, columnID)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle protoMessageUpdateTmpl error: %v", err)
}
if !isWebProto {
protoMessageUpdateCode = strings.ReplaceAll(protoMessageUpdateCode, ` [(tagger.tags) = "uri:\"id\"" ]`, "")
protoMessageUpdateCode = strings.ReplaceAll(protoMessageUpdateCode, `, (tagger.tags) = "uri:\"id\""`, "")
}
protoMessageDetailCode, err := tmplExecuteWithFilter(data, protoMessageDetailTmpl, columnID, columnCreatedAt, columnUpdatedAt)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle protoMessageDetailTmpl error: %v", err)
}
code = strings.ReplaceAll(code, "// protoMessageCreateCode", protoMessageCreateCode)
@@ -901,13 +969,13 @@ func getServiceStructCode(data tmplData) (string, error) {
serviceCreateStructCode, err := tmplExecuteWithFilter(data, serviceCreateStructTmpl)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle serviceCreateStructTmpl error: %v", err)
}
serviceCreateStructCode = strings.ReplaceAll(serviceCreateStructCode, "ID:", "Id:")
serviceUpdateStructCode, err := tmplExecuteWithFilter(data, serviceUpdateStructTmpl, columnID)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle serviceUpdateStructTmpl error: %v", err)
}
serviceUpdateStructCode = strings.ReplaceAll(serviceUpdateStructCode, "ID:", "Id:")
@@ -1017,7 +1085,7 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
}
// nolint
func goTypeToProto(fields []tmplField, jsonNameType int) []tmplField {
func goTypeToProto(fields []tmplField, jsonNameType int, isCommonStyle bool) []tmplField {
var newFields []tmplField
for _, field := range fields {
switch field.GoType {
@@ -1053,7 +1121,7 @@ func goTypeToProto(fields []tmplField, jsonNameType int) []tmplField {
field.GoType = "repeated string"
}
} else {
if strings.ToLower(field.Name) == "id" {
if strings.ToLower(field.Name) == "id" && !isCommonStyle {
field.GoType = "uint64"
}
}
@@ -1154,7 +1222,7 @@ func toCamel(s string) string {
return str
}
func firstLetterToLow(str string) string {
func firstLetterToLower(str string) string {
if len(str) == 0 {
return str
}
@@ -1167,7 +1235,7 @@ func firstLetterToLow(str string) string {
}
func customToCamel(str string) string {
str = firstLetterToLow(toCamel(str))
str = firstLetterToLower(toCamel(str))
if len(str) == 2 {
if str == "iD" {

View File

@@ -4,12 +4,100 @@ import (
"fmt"
"testing"
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/mysql"
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/types"
"github.com/jinzhu/inflection"
"github.com/stretchr/testify/assert"
"github.com/zhufuyi/sqlparser/dependency/mysql"
"github.com/zhufuyi/sqlparser/dependency/types"
)
func TestParseSql(t *testing.T) {
func TestParseSQL(t *testing.T) {
sqls := []string{`create table user
(
id bigint unsigned auto_increment
primary key,
created_at datetime null,
updated_at datetime null,
deleted_at datetime null,
name char(50) not null comment '用户名',
password char(100) not null comment '密码',
email char(50) not null comment '邮件',
phone bigint unsigned not null comment '手机号码',
age tinyint not null comment '年龄',
gender tinyint not null comment '性别1:男2:女3:未知',
status tinyint not null comment '账号状态1:未激活2:已激活3:封禁',
login_state tinyint not null comment '登录状态1:未登录2:已登录',
constraint user_email_uindex
unique (email)
);`,
`create table user_order
(
id varchar(36) not null comment '订单id'
primary key,
product_id varchar(36) not null comment '商品id',
user_id bigint unsigned not null comment '用户id',
status smallint null comment '0:未支付, 1:已支付, 2:已取消',
created_at timestamp null comment '创建时间',
updated_at timestamp null comment '更新时间'
);`,
`create table user_str
(
user_id varchar(36) not null comment '用户id'
primary key,
username varchar(50) not null comment '用户名',
email varchar(100) not null comment '邮箱',
created_at datetime null comment '创建时间',
constraint email
unique (email)
);`,
`create table user_no_primary
(
username varchar(50) not null comment '用户名',
email varchar(100) not null comment '邮箱',
user_id varchar(36) not null comment '用户id',
created_at datetime null comment '创建时间',
constraint email
unique (email)
);`}
for _, sql := range sqls {
codes, err := ParseSQL(sql, WithJSONTag(0), WithEmbed())
assert.Nil(t, err)
for k, v := range codes {
assert.NotEmpty(t, k)
assert.NotEmpty(t, v)
}
//printCode(codes)
codes, err = ParseSQL(sql, WithJSONTag(1), WithWebProto(), WithDBDriver(DBDriverMysql))
assert.Nil(t, err)
for k, v := range codes {
assert.NotEmpty(t, k)
assert.NotEmpty(t, v)
}
//printCode(codes)
codes, err = ParseSQL(sql, WithJSONTag(0), WithDBDriver(DBDriverPostgresql))
assert.Nil(t, err)
for k, v := range codes {
assert.NotEmpty(t, k)
assert.NotEmpty(t, v)
}
//printCode(codes)
codes, err = ParseSQL(sql, WithJSONTag(0), WithDBDriver(DBDriverSqlite))
assert.Nil(t, err)
for k, v := range codes {
assert.NotEmpty(t, k)
assert.NotEmpty(t, v)
}
//printCode(codes)
}
}
func TestParseSqlWithTablePrefix(t *testing.T) {
sql := `CREATE TABLE t_person_info (
id BIGINT(11) PRIMARY KEY AUTO_INCREMENT NOT NULL COMMENT 'id',
age INT(11) unsigned NULL,
@@ -177,7 +265,7 @@ func Test_goTypeToProto(t *testing.T) {
{GoType: "uint"},
{GoType: "time.Time"},
}
v := goTypeToProto(fields, 1)
v := goTypeToProto(fields, 1, false)
assert.NotNil(t, v)
}
@@ -206,14 +294,14 @@ func Test_initTemplate(t *testing.T) {
}
func TestGetMysqlTableInfo(t *testing.T) {
info, err := GetMysqlTableInfo("root:123456@(192.168.3.37:3306)/test", "user")
info, err := GetMysqlTableInfo("root:123456@(192.168.3.37:3306)/account", "user_order")
t.Log(err, info)
}
func TestGetPostgresqlTableInfo(t *testing.T) {
var (
dbname = "account"
tableName = "user_example"
tableName = "user_order"
dsn = fmt.Sprintf("host=192.168.3.37 port=5432 user=root password=123456 dbname=%s sslmode=disable", dbname)
)
@@ -228,8 +316,13 @@ func TestGetPostgresqlTableInfo(t *testing.T) {
t.Log(fieldTypes)
}
func Test_getPostgresqlTableFields(t *testing.T) {
defer func() { _ = recover() }()
_, _ = getPostgresqlTableFields(nil, "foobar")
}
func TestGetSqliteTableInfo(t *testing.T) {
info, err := GetSqliteTableInfo("..\\..\\..\\test\\sql\\sqlite\\sponge.db", "user_example")
info, err := GetSqliteTableInfo("..\\..\\..\\test\\sql\\sqlite\\sponge.db", "user_order")
t.Log(err, info)
}
@@ -260,7 +353,7 @@ func TestConvertToSQLByPgFields(t *testing.T) {
t.Log(sql, tps)
}
func Test_toMysqlTable(t *testing.T) {
func Test_PGField_getMysqlType(t *testing.T) {
fields := []*PGField{
{Type: "smallint"},
{Type: "bigint"},
@@ -277,7 +370,23 @@ func Test_toMysqlTable(t *testing.T) {
{Type: "boolean"},
}
for _, field := range fields {
t.Log(toMysqlType(field), getType(field))
t.Log(field.getMysqlType(), getType(field))
}
}
func Test_SqliteField_getMysqlType(t *testing.T) {
fields := []*SqliteField{
{Type: "integer"},
{Type: "text"},
{Type: "real"},
{Type: "numeric"},
{Type: "blob"},
{Type: "datetime"},
{Type: "boolean"},
{Type: "unknown_type"},
}
for _, field := range fields {
t.Log(field.getMysqlType())
}
}
@@ -288,9 +397,9 @@ func printCode(code map[string]string) {
}
func printPGFields(fields []*PGField) {
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v\n", "Name", "Type", "Length", "Lengthvar", "Notnull", "Comment")
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v %-20v\n", "Name", "Type", "Length", "Lengthvar", "Notnull", "Comment", "IsPrimaryKey")
for _, p := range fields {
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v\n", p.Name, p.Type, p.Length, p.Lengthvar, p.Notnull, p.Comment)
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v %-20v\n", p.Name, p.Type, p.Length, p.Lengthvar, p.Notnull, p.Comment, p.IsPrimaryKey)
}
}
@@ -442,3 +551,82 @@ func Test_embedTimeFields(t *testing.T) {
fields = embedTimeField(names, []*MgoField{})
t.Log(fields)
}
func TestCrudInfo(t *testing.T) {
data := tmplData{
TableName: "User",
TName: "user",
NameFunc: false,
RawTableName: "user",
Fields: []tmplField{
{
ColName: "name",
Name: "Name",
GoType: "string",
Tag: "json:\"name\"",
Comment: "姓名",
JSONName: "name",
DBDriver: "mysql",
},
{
ColName: "age",
Name: "Age",
GoType: "int",
Tag: "json:\"age\"",
Comment: "年龄",
JSONName: "age",
DBDriver: "mysql",
},
{
ColName: "created_at",
Name: "CreatedAt",
GoType: "time.Time",
Tag: "json:\"created_at\"",
Comment: "创建时间",
JSONName: "createdAt",
DBDriver: "mysql",
},
},
Comment: "用户信息",
SubStructs: "",
ProtoSubStructs: "",
DBDriver: "mysql",
}
info := newCrudInfo(data)
isPrimary := info.isIDPrimaryKey()
assert.Equal(t, false, isPrimary)
code := info.getCode()
assert.Contains(t, code, `"tableNameCamel":"User","tableNameCamelFCL":"user"`)
grpcValidation := info.GetGRPCProtoValidation()
assert.Contains(t, grpcValidation, "validate.rules")
webValidation := info.GetWebProtoValidation()
assert.Contains(t, webValidation, "validate.rules")
info = nil
_ = info.isIDPrimaryKey()
_ = info.getCode()
_ = info.GetGRPCProtoValidation()
_ = info.GetWebProtoValidation()
}
func Test_customEndOfLetterToLower(t *testing.T) {
names := []string{
"ID",
"IP",
"userID",
"orderID",
"LocalIP",
"bus",
"BUS",
"x",
"s",
}
for _, name := range names {
t.Log(customEndOfLetterToLower(name, inflection.Plural(name)))
}
}

View File

@@ -8,18 +8,8 @@ import (
"gorm.io/gorm"
)
// PGField postgresql field
type PGField struct {
Name string `gorm:"column:name;" json:"name"`
Type string `gorm:"column:type;" json:"type"`
Comment string `gorm:"column:comment;" json:"comment"`
Length int `gorm:"column:length;" json:"length"`
Lengthvar int `gorm:"column:lengthvar;" json:"lengthvar"`
Notnull bool `gorm:"column:notnull;" json:"notnull"`
}
// GetPostgresqlTableInfo get table info from postgres
func GetPostgresqlTableInfo(dsn string, tableName string) ([]*PGField, error) {
func GetPostgresqlTableInfo(dsn string, tableName string) (PGFields, error) {
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("GetPostgresqlTableInfo error: %v", err)
@@ -29,51 +19,47 @@ func GetPostgresqlTableInfo(dsn string, tableName string) ([]*PGField, error) {
return getPostgresqlTableFields(db, tableName)
}
func getPostgresqlTableFields(db *gorm.DB, tableName string) ([]*PGField, error) {
query := fmt.Sprintf(`SELECT a.attname AS name, t.typname AS type, a.attlen AS length, a.atttypmod AS lengthvar, a.attnotnull AS notnull, b.description AS comment
FROM pg_class c, pg_attribute a
LEFT JOIN pg_description b
ON a.attrelid = b.objoid
AND a.attnum = b.objsubid, pg_type t
WHERE c.relname = '%s'
AND a.attnum > 0
AND a.attrelid = c.oid
AND a.atttypid = t.oid
ORDER BY a.attnum;`, tableName)
var fields []*PGField
result := db.Raw(query).Scan(&fields)
if result.Error != nil {
return nil, fmt.Errorf("failed to get table fields: %v", result.Error)
}
return fields, nil
}
// ConvertToSQLByPgFields convert to mysql table ddl
func ConvertToSQLByPgFields(tableName string, fields []*PGField) (string, map[string]string) {
func ConvertToSQLByPgFields(tableName string, fields PGFields) (string, map[string]string) {
fieldStr := ""
pgTypeMap := make(map[string]string) // name:type
if len(fields) == 0 {
return "", pgTypeMap
}
for _, field := range fields {
pgTypeMap[field.Name] = getType(field)
sqlType := toMysqlType(field)
if field.Name == "id" {
fieldStr += fmt.Sprintf(" %s bigint unsigned primary key,\n", field.Name)
continue
}
sqlType := field.getMysqlType()
notnullStr := "not null"
if !field.Notnull {
notnullStr = "null"
}
fieldStr += fmt.Sprintf(" `%s` %s %s comment '%s',\n", field.Name, sqlType, notnullStr, field.Comment)
}
fieldStr = strings.TrimSuffix(fieldStr, ",\n")
primaryField := fields.getPrimaryField()
if primaryField != nil {
fieldStr += fmt.Sprintf(" PRIMARY KEY (`%s`)\n", primaryField.Name)
} else {
fieldStr = strings.TrimSuffix(fieldStr, ",\n")
}
sqlStr := fmt.Sprintf("CREATE TABLE `%s` (\n%s\n);", tableName, fieldStr)
return sqlStr, pgTypeMap
}
// PGField postgresql field
type PGField struct {
Name string `gorm:"column:name;" json:"name"`
Type string `gorm:"column:type;" json:"type"`
Comment string `gorm:"column:comment;" json:"comment"`
Length int `gorm:"column:length;" json:"length"`
Lengthvar int `gorm:"column:lengthvar;" json:"lengthvar"`
Notnull bool `gorm:"column:notnull;" json:"notnull"`
IsPrimaryKey bool `gorm:"column:is_primary_key;" json:"is_primary_key"`
}
// nolint
func toMysqlType(field *PGField) string {
func (field *PGField) getMysqlType() string {
switch field.Type {
case "smallint", "integer", "smallserial", "serial", "int2", "int4":
return "int"
@@ -117,6 +103,79 @@ func toMysqlType(field *PGField) string {
return field.Type
}
type PGFields []*PGField
func (fields PGFields) getPrimaryField() *PGField {
var f *PGField
for _, field := range fields {
if field.IsPrimaryKey || field.Name == "id" {
f = field
return f
}
}
/*
// if no primary key, find the first xxx_id field
if f == nil {
for _, field := range fields {
if strings.HasSuffix(field.Name, "_id") {
f = field
f.IsPrimaryKey = true
return f
}
}
}
// if no xxx_id field, find the first field
if f == nil {
for _, field := range fields {
f = field
f.IsPrimaryKey = true
return f
}
}
*/
return f
}
func getPostgresqlTableFields(db *gorm.DB, tableName string) (PGFields, error) {
query := fmt.Sprintf(`SELECT
a.attname AS name,
t.typname AS type,
a.attlen AS length,
a.atttypmod AS lengthvar,
a.attnotnull AS notnull,
b.description AS comment,
CASE
WHEN pk.constraint_type = 'PRIMARY KEY' THEN true
ELSE false
END AS is_primary_key
FROM pg_class c
JOIN pg_attribute a ON a.attrelid = c.oid
LEFT JOIN pg_description b ON a.attrelid = b.objoid AND a.attnum = b.objsubid
JOIN pg_type t ON a.atttypid = t.oid
LEFT JOIN (
SELECT
kcu.column_name,
con.constraint_type
FROM information_schema.table_constraints con
JOIN information_schema.key_column_usage kcu
ON con.constraint_name = kcu.constraint_name
WHERE con.constraint_type = 'PRIMARY KEY'
AND con.table_name = '%s'
) AS pk ON a.attname = pk.column_name
WHERE c.relname = '%s'
AND a.attnum > 0
ORDER BY a.attnum;`, tableName, tableName)
var fields PGFields
result := db.Raw(query).Scan(&fields)
if result.Error != nil {
return nil, fmt.Errorf("failed to get table fields: %v", result.Error)
}
return fields, nil
}
func getType(field *PGField) string {
switch field.Type {
case "character", "character varying", "varchar", "char", "bpchar":

View File

@@ -1,24 +1,12 @@
package parser
import (
"fmt"
"strings"
"github.com/zhufuyi/sponge/pkg/ggorm"
)
var sqliteToMysqlTypeMap = map[string]string{
" INTEGER ": " INT ",
" REAL ": " FLOAT ",
" BOOLEAN ": " TINYINT ",
" NUMERIC ": " VARCHAR(255) ",
"AUTOINCREMENT": "auto_increment",
" integer ": " INT ",
" real ": " FLOAT ",
" boolean ": " TINYINT ",
" numeric ": " VARCHAR(255) ",
"autoincrement": "auto_increment",
}
// GetSqliteTableInfo get table info from sqlite
func GetSqliteTableInfo(dbFile string, tableName string) (string, error) {
db, err := ggorm.InitSqlite(dbFile)
@@ -27,16 +15,102 @@ func GetSqliteTableInfo(dbFile string, tableName string) (string, error) {
}
defer closeDB(db)
var sql string
err = db.Raw("select sql from sqlite_master where type = ? and name = ?", "table", tableName).Scan(&sql).Error
var sqliteFields SqliteFields
sql := fmt.Sprintf("PRAGMA table_info('%s')", tableName)
err = db.Raw(sql).Scan(&sqliteFields).Error
if err != nil {
return "", err
}
for k, v := range sqliteToMysqlTypeMap {
sql = strings.ReplaceAll(sql, k, v)
}
sql = strings.ReplaceAll(sql, "\"", "")
return sql, nil
return convertToSQLBySqliteFields(tableName, sqliteFields), nil
}
// SqliteField sqlite field struct
type SqliteField struct {
Cid int `gorm:"column:cid" json:"cid"`
Name string `gorm:"column:name" json:"name"`
Type string `gorm:"column:type" json:"type"`
Notnull int `gorm:"column:notnull" json:"notnull"`
DefaultValue string `gorm:"column:dflt_value" json:"dflt_value"`
Pk int `gorm:"column:pk" json:"pk"`
}
var sqliteToMysqlType = map[string]string{
"integer": "INT",
"text": "TEXT",
"real": "FLOAT",
"datetime": "DATETIME",
"blob": "BLOB",
"boolean": "TINYINT",
"numeric": " VARCHAR(255)",
"autoincrement": "auto_increment",
}
func (field *SqliteField) getMysqlType() string {
sqliteType := strings.ToLower(field.Type)
if mysqlType, ok := sqliteToMysqlType[sqliteType]; ok {
if field.Name == "id" && sqliteType == "text" {
return "VARCHAR(50)"
}
return mysqlType
}
return "VARCHAR(100)"
}
// SqliteFields sqlite fields
type SqliteFields []*SqliteField
func (fields SqliteFields) getPrimaryField() *SqliteField {
var f *SqliteField
for _, field := range fields {
if field.Pk == 1 || field.Name == "id" {
f = field
return f
}
}
/*
// if no primary key, find the first xxx_id field
if f == nil {
for _, field := range fields {
if strings.HasSuffix(field.Name, "_id") {
f = field
f.Pk = 1
return f
}
}
}
// if no xxx_id field, find the first field
if f == nil {
for _, field := range fields {
f = field
f.Pk = 1
return f
}
}
*/
return f
}
func convertToSQLBySqliteFields(tableName string, fields SqliteFields) string {
if len(fields) == 0 {
return ""
}
fieldStr := ""
for _, field := range fields {
notnullStr := "not null"
if field.Notnull == 0 {
notnullStr = "null"
}
fieldStr += fmt.Sprintf(" `%s` %s %s comment '%s',\n", field.Name, field.getMysqlType(), notnullStr, "")
}
primaryField := fields.getPrimaryField()
if primaryField != nil {
fieldStr += fmt.Sprintf(" PRIMARY KEY (`%s`)\n", primaryField.Name)
} else {
fieldStr = strings.TrimSuffix(fieldStr, ",\n")
}
return fmt.Sprintf("CREATE TABLE `%s` (\n%s\n);", tableName, fieldStr)
}

View File

@@ -297,9 +297,22 @@ import "validate/validate.proto";
option go_package = "github.com/zhufuyi/sponge/api/serverNameExample/v1;v1";
// Default settings for generating swagger documents
// NOTE: because json does not support 64 bits, the int64 and uint64 types under *.swagger.json are automatically converted to string types
// Reference https://github.com/grpc-ecosystem/grpc-gateway/blob/db7fbefff7c04877cdb32e16d4a248a024428207/examples/internal/proto/examplepb/a_bit_of_everything.proto
/*
Reference https://github.com/grpc-ecosystem/grpc-gateway/blob/db7fbefff7c04877cdb32e16d4a248a024428207/examples/internal/proto/examplepb/a_bit_of_everything.proto
Default settings for generating swagger documents
NOTE: because json does not support 64 bits, the int64 and uint64 types under *.swagger.json are automatically converted to string types
Tips: add swagger option to rpc method, example:
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "get user by id",
description: "get user by id",
security: {
security_requirement: {
key: "BearerAuth";
value: {}
}
}
};
*/
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
host: "localhost:8080"
base_path: ""
@@ -331,16 +344,6 @@ service {{.TName}} {
post: "/api/v1/{{.TName}}"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "create {{.TName}}",
description: "submit information to create {{.TName}}",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// delete {{.TName}} by id
@@ -348,16 +351,6 @@ service {{.TName}} {
option (google.api.http) = {
delete: "/api/v1/{{.TName}}/{id}"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "delete {{.TName}}",
description: "delete {{.TName}} by id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// update {{.TName}} by id
@@ -366,16 +359,6 @@ service {{.TName}} {
put: "/api/v1/{{.TName}}/{id}"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "update {{.TName}}",
description: "update {{.TName}} by id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// get {{.TName}} by id
@@ -383,16 +366,6 @@ service {{.TName}} {
option (google.api.http) = {
get: "/api/v1/{{.TName}}/{id}"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "get {{.TName}} detail",
description: "get {{.TName}} detail by id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// list of {{.TName}} by query parameters
@@ -401,16 +374,6 @@ service {{.TName}} {
post: "/api/v1/{{.TName}}/list"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "list of {{.TName}}s by parameters",
description: "list of {{.TName}}s by paging and conditions",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// delete {{.TName}} by batch id
@@ -419,16 +382,6 @@ service {{.TName}} {
post: "/api/v1/{{.TName}}/delete/ids"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "delete {{.TName}}s by batch id",
description: "delete {{.TName}}s by batch id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// get {{.TName}} by condition
@@ -437,16 +390,6 @@ service {{.TName}} {
post: "/api/v1/{{.TName}}/condition"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "get {{.TName}} detail by condition",
description: "get {{.TName}} detail by condition",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// list of {{.TName}} by batch id
@@ -455,16 +398,6 @@ service {{.TName}} {
post: "/api/v1/{{.TName}}/list/ids"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "list of {{.TName}}s by batch id",
description: "list of {{.TName}}s by batch id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// list {{.TName}} by last id
@@ -472,16 +405,6 @@ service {{.TName}} {
option (google.api.http) = {
get: "/api/v1/{{.TName}}/list"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "list of {{.TName}} by last id",
description: "list of {{.TName}} by last id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
}
@@ -595,9 +518,22 @@ import "validate/validate.proto";
option go_package = "github.com/zhufuyi/sponge/api/serverNameExample/v1;v1";
// Default settings for generating swagger documents
// NOTE: because json does not support 64 bits, the int64 and uint64 types under *.swagger.json are automatically converted to string types
// Reference https://github.com/grpc-ecosystem/grpc-gateway/blob/db7fbefff7c04877cdb32e16d4a248a024428207/examples/internal/proto/examplepb/a_bit_of_everything.proto
/*
Reference https://github.com/grpc-ecosystem/grpc-gateway/blob/db7fbefff7c04877cdb32e16d4a248a024428207/examples/internal/proto/examplepb/a_bit_of_everything.proto
Default settings for generating swagger documents
NOTE: because json does not support 64 bits, the int64 and uint64 types under *.swagger.json are automatically converted to string types
Tips: add swagger option to rpc method, example:
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "get user by id",
description: "get user by id",
security: {
security_requirement: {
key: "BearerAuth";
value: {}
}
}
};
*/
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
host: "localhost:8080"
base_path: ""
@@ -629,16 +565,6 @@ service {{.TName}} {
post: "/api/v1/{{.TName}}"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "create {{.TName}}",
description: "submit information to create {{.TName}}",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// delete {{.TName}} by id
@@ -646,16 +572,6 @@ service {{.TName}} {
option (google.api.http) = {
delete: "/api/v1/{{.TName}}/{id}"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "delete {{.TName}}",
description: "delete {{.TName}} by id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// update {{.TName}} by id
@@ -664,16 +580,6 @@ service {{.TName}} {
put: "/api/v1/{{.TName}}/{id}"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "update {{.TName}}",
description: "update {{.TName}} by id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// get {{.TName}} by id
@@ -681,16 +587,6 @@ service {{.TName}} {
option (google.api.http) = {
get: "/api/v1/{{.TName}}/{id}"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "get {{.TName}} detail",
description: "get {{.TName}} detail by id",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
// list of {{.TName}} by query parameters
@@ -699,16 +595,6 @@ service {{.TName}} {
post: "/api/v1/{{.TName}}/list"
body: "*"
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "list of {{.TName}}s by parameters",
description: "list of {{.TName}}s by paging and conditions",
//security: {
// security_requirement: {
// key: "BearerAuth";
// value: {}
// }
//}
};
}
}