feat: support custom table primary key type and name

This commit is contained in:
zhuyasen
2024-11-04 20:30:28 +08:00
parent 4deb203a02
commit 6eb9e51b9d
51 changed files with 5514 additions and 344 deletions

View File

@@ -5,18 +5,19 @@ package parser
import (
"bufio"
"bytes"
"errors"
"fmt"
"go/format"
"sort"
"strings"
"text/template"
"github.com/blastrain/vitess-sqlparser/tidbparser/ast"
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/mysql"
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/types"
"github.com/blastrain/vitess-sqlparser/tidbparser/parser"
"github.com/huandu/xstrings"
"github.com/jinzhu/inflection"
"github.com/zhufuyi/sqlparser/ast"
"github.com/zhufuyi/sqlparser/dependency/mysql"
"github.com/zhufuyi/sqlparser/dependency/types"
"github.com/zhufuyi/sqlparser/parser"
)
const (
@@ -34,6 +35,8 @@ const (
CodeTypeProto = "proto"
// CodeTypeService grpc service code
CodeTypeService = "service"
// CodeTypeCrudInfo crud info json data
CodeTypeCrudInfo = "crud_info"
// DBDriverMysql mysql driver
DBDriverMysql = "mysql"
@@ -68,6 +71,7 @@ type modelCodes struct {
// ParseSQL generate different usage codes based on sql
func ParseSQL(sql string, options ...Option) (map[string]string, error) {
initTemplate()
initCommonTemplate()
opt := parseOption(options)
stmts, err := parser.New().Parse(sql, opt.Charset, opt.Collation)
@@ -82,6 +86,7 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
modelJSONCodes := make([]string, 0, len(stmts))
importPath := make(map[string]struct{})
tableNames := make([]string, 0, len(stmts))
primaryKeysCodes := make([]string, 0, len(stmts))
for _, stmt := range stmts {
if ct, ok := stmt.(*ast.CreateTableStmt); ok {
code, err2 := makeCode(ct, opt)
@@ -95,6 +100,7 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
serviceStructCodes = append(serviceStructCodes, code.serviceStruct)
modelJSONCodes = append(modelJSONCodes, code.modelJSON)
tableNames = append(tableNames, toCamel(ct.Table.Name.String()))
primaryKeysCodes = append(primaryKeysCodes, code.crudInfo)
for _, s := range code.importPaths {
importPath[s] = struct{}{}
}
@@ -118,13 +124,14 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
}
var codesMap = map[string]string{
CodeTypeModel: modelCode,
CodeTypeJSON: strings.Join(modelJSONCodes, "\n\n"),
CodeTypeDAO: strings.Join(updateFieldsCodes, "\n\n"),
CodeTypeHandler: strings.Join(handlerStructCodes, "\n\n"),
CodeTypeProto: strings.Join(protoFileCodes, "\n\n"),
CodeTypeService: strings.Join(serviceStructCodes, "\n\n"),
TableName: strings.Join(tableNames, ", "),
CodeTypeModel: modelCode,
CodeTypeJSON: strings.Join(modelJSONCodes, "\n\n"),
CodeTypeDAO: strings.Join(updateFieldsCodes, "\n\n"),
CodeTypeHandler: strings.Join(handlerStructCodes, "\n\n"),
CodeTypeProto: strings.Join(protoFileCodes, "\n\n"),
CodeTypeService: strings.Join(serviceStructCodes, "\n\n"),
TableName: strings.Join(tableNames, ", "),
CodeTypeCrudInfo: strings.Join(primaryKeysCodes, "||||"),
}
return codesMap, nil
@@ -140,16 +147,19 @@ type tmplData struct {
SubStructs string // sub structs for model
ProtoSubStructs string // sub structs for protobuf
DBDriver string
CrudInfo *CrudInfo
}
type tmplField struct {
Name string
ColName string
GoType string
Tag string
Comment string
JSONName string
DBDriver string
IsPrimaryKey bool // is primary key
ColName string // table column name
Name string // convert to camel case
GoType string // convert to go type
Tag string
Comment string
JSONName string
DBDriver string
rewriterField *rewriterField
}
@@ -159,6 +169,13 @@ type rewriterField struct {
path string
}
func (d tmplData) isCommonStyle(isEmbed bool) bool {
if d.DBDriver != DBDriverMongodb && !isEmbed && !d.CrudInfo.isIDPrimaryKey() {
return true
}
return false
}
// ConditionZero type of condition 0, used in dao template code
func (t tmplField) ConditionZero() string {
switch t.GoType {
@@ -260,12 +277,33 @@ func (t tmplField) AddOne(i int) int {
// AddOneWithTag counter and add id tag
func (t tmplField) AddOneWithTag(i int) string {
if t.ColName == "id" {
return fmt.Sprintf(`%d [(tagger.tags) = "uri:\"id\"" ]`, i+1)
if t.DBDriver == DBDriverMongodb {
return fmt.Sprintf(`%d [(validate.rules).string.min_len = 6, (tagger.tags) = "uri:\"id\""]`, i+1)
}
return fmt.Sprintf(`%d [(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"id\""]`, i+1, t.GoType)
}
return fmt.Sprintf("%d", i+1)
}
func (t tmplField) AddOneWithTag2(i int) string {
if t.IsPrimaryKey || t.ColName == "id" {
if t.GoType == "string" {
return fmt.Sprintf(`%d [(validate.rules).string.min_len = 1, (tagger.tags) = "uri:\"%s\""]`, i+1, t.JSONName)
}
return fmt.Sprintf(`%d [(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"%s\""]`, i+1, t.GoType, t.JSONName)
}
return fmt.Sprintf("%d", i+1)
}
func getProtoFieldName(fields []tmplField) string {
for _, field := range fields {
if field.IsPrimaryKey || field.ColName == "id" {
return field.JSONName
}
}
return ""
}
const (
__mysqlModel__ = "__mysqlModel__" //nolint
__type__ = "__type__" //nolint
@@ -312,6 +350,7 @@ type codeText struct {
handlerStruct string
protoFile string
serviceStruct string
crudInfo string
}
// nolint
@@ -320,7 +359,8 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
data := tmplData{
TableName: stmt.Table.Name.String(),
RawTableName: stmt.Table.Name.String(),
Fields: make([]tmplField, 0, 1),
//Fields: make([]tmplField, 0, 1),
DBDriver: opt.DBDriver,
}
tablePrefix := opt.TablePrefix
if tablePrefix != "" && strings.HasPrefix(data.TableName, tablePrefix) {
@@ -340,7 +380,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
}
data.TableName = toCamel(data.TableName)
data.TName = firstLetterToLow(data.TableName)
data.TName = firstLetterToLower(data.TableName)
// find table comment
for _, o := range stmt.Options {
@@ -391,6 +431,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
}
}
if isPrimaryKey[colName] {
field.IsPrimaryKey = true
gormTag.WriteString(";primary_key")
}
isNotNull := false
@@ -444,7 +485,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
default: // gorm
if !isPrimaryKey[colName] && isNotNull {
gormTag.WriteString(";NOT NULL")
gormTag.WriteString(";not null")
}
tags = append(tags, "gorm", gormTag.String())
@@ -473,24 +514,26 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
data.Fields = append(data.Fields, field)
}
if len(data.Fields) == 0 {
return nil, errors.New("no columns found in table " + data.TableName)
}
data.CrudInfo = newCrudInfo(data)
data.CrudInfo.IsCommonType = data.isCommonStyle(opt.IsEmbed)
if v, ok := opt.FieldTypes[SubStructKey]; ok {
data.SubStructs = v
}
if v, ok := opt.FieldTypes[ProtoSubStructKey]; ok {
data.ProtoSubStructs = v
}
data.DBDriver = opt.DBDriver
updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
if err != nil {
return nil, err
}
handlerStructCode, err := getHandlerStructCodes(data, opt.JSONNamedType)
if err != nil {
return nil, err
}
modelStructCode, importPaths, err := getModelStructCode(data, importPath, opt.IsEmbed, opt.JSONNamedType)
if err != nil {
return nil, err
@@ -501,14 +544,35 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
return nil, err
}
protoFileCode, err := getProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
if err != nil {
return nil, err
}
serviceStructCode, err := getServiceStructCode(data)
if err != nil {
return nil, err
handlerStructCode := ""
serviceStructCode := ""
protoFileCode := ""
if data.isCommonStyle(opt.IsEmbed) {
handlerStructCode, err = getCommonHandlerStructCodes(data, opt.JSONNamedType)
if err != nil {
return nil, err
}
serviceStructCode, err = getCommonServiceStructCode(data)
if err != nil {
return nil, err
}
protoFileCode, err = getCommonProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
if err != nil {
return nil, err
}
} else {
handlerStructCode, err = getHandlerStructCodes(data, opt.JSONNamedType)
if err != nil {
return nil, err
}
serviceStructCode, err = getServiceStructCode(data)
if err != nil {
return nil, err
}
protoFileCode, err = getProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
if err != nil {
return nil, err
}
}
return &codeText{
@@ -519,6 +583,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
handlerStruct: handlerStructCode,
protoFile: protoFileCode,
serviceStruct: serviceStructCode,
crudInfo: data.CrudInfo.getCode(),
}, nil
}
@@ -586,6 +651,9 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
// force conversion of ID field to uint64 type
if field.Name == "ID" {
data.Fields[i].GoType = "uint64"
if data.isCommonStyle(isEmbed) {
data.Fields[i].GoType = data.CrudInfo.GoType
}
}
switch field.DBDriver {
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
@@ -763,7 +831,7 @@ func getModelJSONCode(data tmplData) (string, error) {
}
func getProtoFileCode(data tmplData, jsonNamedType int, isWebProto bool, isExtendedAPI bool) (string, error) {
data.Fields = goTypeToProto(data.Fields, jsonNamedType)
data.Fields = goTypeToProto(data.Fields, jsonNamedType, false)
var err error
builder := strings.Builder{}
@@ -790,20 +858,20 @@ func getProtoFileCode(data tmplData, jsonNamedType int, isWebProto bool, isExten
protoMessageCreateCode, err := tmplExecuteWithFilter(data, protoMessageCreateTmpl)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle protoMessageCreateTmpl error: %v", err)
}
protoMessageUpdateCode, err := tmplExecuteWithFilter(data, protoMessageUpdateTmpl, columnID)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle protoMessageUpdateTmpl error: %v", err)
}
if !isWebProto {
protoMessageUpdateCode = strings.ReplaceAll(protoMessageUpdateCode, ` [(tagger.tags) = "uri:\"id\"" ]`, "")
protoMessageUpdateCode = strings.ReplaceAll(protoMessageUpdateCode, `, (tagger.tags) = "uri:\"id\""`, "")
}
protoMessageDetailCode, err := tmplExecuteWithFilter(data, protoMessageDetailTmpl, columnID, columnCreatedAt, columnUpdatedAt)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle protoMessageDetailTmpl error: %v", err)
}
code = strings.ReplaceAll(code, "// protoMessageCreateCode", protoMessageCreateCode)
@@ -901,13 +969,13 @@ func getServiceStructCode(data tmplData) (string, error) {
serviceCreateStructCode, err := tmplExecuteWithFilter(data, serviceCreateStructTmpl)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle serviceCreateStructTmpl error: %v", err)
}
serviceCreateStructCode = strings.ReplaceAll(serviceCreateStructCode, "ID:", "Id:")
serviceUpdateStructCode, err := tmplExecuteWithFilter(data, serviceUpdateStructTmpl, columnID)
if err != nil {
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
return "", fmt.Errorf("handle serviceUpdateStructTmpl error: %v", err)
}
serviceUpdateStructCode = strings.ReplaceAll(serviceUpdateStructCode, "ID:", "Id:")
@@ -1017,7 +1085,7 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
}
// nolint
func goTypeToProto(fields []tmplField, jsonNameType int) []tmplField {
func goTypeToProto(fields []tmplField, jsonNameType int, isCommonStyle bool) []tmplField {
var newFields []tmplField
for _, field := range fields {
switch field.GoType {
@@ -1053,7 +1121,7 @@ func goTypeToProto(fields []tmplField, jsonNameType int) []tmplField {
field.GoType = "repeated string"
}
} else {
if strings.ToLower(field.Name) == "id" {
if strings.ToLower(field.Name) == "id" && !isCommonStyle {
field.GoType = "uint64"
}
}
@@ -1154,7 +1222,7 @@ func toCamel(s string) string {
return str
}
func firstLetterToLow(str string) string {
func firstLetterToLower(str string) string {
if len(str) == 0 {
return str
}
@@ -1167,7 +1235,7 @@ func firstLetterToLow(str string) string {
}
func customToCamel(str string) string {
str = firstLetterToLow(toCamel(str))
str = firstLetterToLower(toCamel(str))
if len(str) == 2 {
if str == "iD" {