mirror of
https://github.com/zhufuyi/sponge.git
synced 2025-12-24 10:40:55 +08:00
feat: support custom table primary key type and name
This commit is contained in:
@@ -5,18 +5,19 @@ package parser
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/blastrain/vitess-sqlparser/tidbparser/ast"
|
||||
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/mysql"
|
||||
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/types"
|
||||
"github.com/blastrain/vitess-sqlparser/tidbparser/parser"
|
||||
"github.com/huandu/xstrings"
|
||||
"github.com/jinzhu/inflection"
|
||||
"github.com/zhufuyi/sqlparser/ast"
|
||||
"github.com/zhufuyi/sqlparser/dependency/mysql"
|
||||
"github.com/zhufuyi/sqlparser/dependency/types"
|
||||
"github.com/zhufuyi/sqlparser/parser"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,6 +35,8 @@ const (
|
||||
CodeTypeProto = "proto"
|
||||
// CodeTypeService grpc service code
|
||||
CodeTypeService = "service"
|
||||
// CodeTypeCrudInfo crud info json data
|
||||
CodeTypeCrudInfo = "crud_info"
|
||||
|
||||
// DBDriverMysql mysql driver
|
||||
DBDriverMysql = "mysql"
|
||||
@@ -68,6 +71,7 @@ type modelCodes struct {
|
||||
// ParseSQL generate different usage codes based on sql
|
||||
func ParseSQL(sql string, options ...Option) (map[string]string, error) {
|
||||
initTemplate()
|
||||
initCommonTemplate()
|
||||
opt := parseOption(options)
|
||||
|
||||
stmts, err := parser.New().Parse(sql, opt.Charset, opt.Collation)
|
||||
@@ -82,6 +86,7 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
|
||||
modelJSONCodes := make([]string, 0, len(stmts))
|
||||
importPath := make(map[string]struct{})
|
||||
tableNames := make([]string, 0, len(stmts))
|
||||
primaryKeysCodes := make([]string, 0, len(stmts))
|
||||
for _, stmt := range stmts {
|
||||
if ct, ok := stmt.(*ast.CreateTableStmt); ok {
|
||||
code, err2 := makeCode(ct, opt)
|
||||
@@ -95,6 +100,7 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
|
||||
serviceStructCodes = append(serviceStructCodes, code.serviceStruct)
|
||||
modelJSONCodes = append(modelJSONCodes, code.modelJSON)
|
||||
tableNames = append(tableNames, toCamel(ct.Table.Name.String()))
|
||||
primaryKeysCodes = append(primaryKeysCodes, code.crudInfo)
|
||||
for _, s := range code.importPaths {
|
||||
importPath[s] = struct{}{}
|
||||
}
|
||||
@@ -118,13 +124,14 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
|
||||
}
|
||||
|
||||
var codesMap = map[string]string{
|
||||
CodeTypeModel: modelCode,
|
||||
CodeTypeJSON: strings.Join(modelJSONCodes, "\n\n"),
|
||||
CodeTypeDAO: strings.Join(updateFieldsCodes, "\n\n"),
|
||||
CodeTypeHandler: strings.Join(handlerStructCodes, "\n\n"),
|
||||
CodeTypeProto: strings.Join(protoFileCodes, "\n\n"),
|
||||
CodeTypeService: strings.Join(serviceStructCodes, "\n\n"),
|
||||
TableName: strings.Join(tableNames, ", "),
|
||||
CodeTypeModel: modelCode,
|
||||
CodeTypeJSON: strings.Join(modelJSONCodes, "\n\n"),
|
||||
CodeTypeDAO: strings.Join(updateFieldsCodes, "\n\n"),
|
||||
CodeTypeHandler: strings.Join(handlerStructCodes, "\n\n"),
|
||||
CodeTypeProto: strings.Join(protoFileCodes, "\n\n"),
|
||||
CodeTypeService: strings.Join(serviceStructCodes, "\n\n"),
|
||||
TableName: strings.Join(tableNames, ", "),
|
||||
CodeTypeCrudInfo: strings.Join(primaryKeysCodes, "||||"),
|
||||
}
|
||||
|
||||
return codesMap, nil
|
||||
@@ -140,16 +147,19 @@ type tmplData struct {
|
||||
SubStructs string // sub structs for model
|
||||
ProtoSubStructs string // sub structs for protobuf
|
||||
DBDriver string
|
||||
|
||||
CrudInfo *CrudInfo
|
||||
}
|
||||
|
||||
type tmplField struct {
|
||||
Name string
|
||||
ColName string
|
||||
GoType string
|
||||
Tag string
|
||||
Comment string
|
||||
JSONName string
|
||||
DBDriver string
|
||||
IsPrimaryKey bool // is primary key
|
||||
ColName string // table column name
|
||||
Name string // convert to camel case
|
||||
GoType string // convert to go type
|
||||
Tag string
|
||||
Comment string
|
||||
JSONName string
|
||||
DBDriver string
|
||||
|
||||
rewriterField *rewriterField
|
||||
}
|
||||
@@ -159,6 +169,13 @@ type rewriterField struct {
|
||||
path string
|
||||
}
|
||||
|
||||
func (d tmplData) isCommonStyle(isEmbed bool) bool {
|
||||
if d.DBDriver != DBDriverMongodb && !isEmbed && !d.CrudInfo.isIDPrimaryKey() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ConditionZero type of condition 0, used in dao template code
|
||||
func (t tmplField) ConditionZero() string {
|
||||
switch t.GoType {
|
||||
@@ -260,12 +277,33 @@ func (t tmplField) AddOne(i int) int {
|
||||
// AddOneWithTag counter and add id tag
|
||||
func (t tmplField) AddOneWithTag(i int) string {
|
||||
if t.ColName == "id" {
|
||||
return fmt.Sprintf(`%d [(tagger.tags) = "uri:\"id\"" ]`, i+1)
|
||||
if t.DBDriver == DBDriverMongodb {
|
||||
return fmt.Sprintf(`%d [(validate.rules).string.min_len = 6, (tagger.tags) = "uri:\"id\""]`, i+1)
|
||||
}
|
||||
return fmt.Sprintf(`%d [(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"id\""]`, i+1, t.GoType)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%d", i+1)
|
||||
}
|
||||
|
||||
func (t tmplField) AddOneWithTag2(i int) string {
|
||||
if t.IsPrimaryKey || t.ColName == "id" {
|
||||
if t.GoType == "string" {
|
||||
return fmt.Sprintf(`%d [(validate.rules).string.min_len = 1, (tagger.tags) = "uri:\"%s\""]`, i+1, t.JSONName)
|
||||
}
|
||||
return fmt.Sprintf(`%d [(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"%s\""]`, i+1, t.GoType, t.JSONName)
|
||||
}
|
||||
return fmt.Sprintf("%d", i+1)
|
||||
}
|
||||
|
||||
func getProtoFieldName(fields []tmplField) string {
|
||||
for _, field := range fields {
|
||||
if field.IsPrimaryKey || field.ColName == "id" {
|
||||
return field.JSONName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const (
|
||||
__mysqlModel__ = "__mysqlModel__" //nolint
|
||||
__type__ = "__type__" //nolint
|
||||
@@ -312,6 +350,7 @@ type codeText struct {
|
||||
handlerStruct string
|
||||
protoFile string
|
||||
serviceStruct string
|
||||
crudInfo string
|
||||
}
|
||||
|
||||
// nolint
|
||||
@@ -320,7 +359,8 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
data := tmplData{
|
||||
TableName: stmt.Table.Name.String(),
|
||||
RawTableName: stmt.Table.Name.String(),
|
||||
Fields: make([]tmplField, 0, 1),
|
||||
//Fields: make([]tmplField, 0, 1),
|
||||
DBDriver: opt.DBDriver,
|
||||
}
|
||||
tablePrefix := opt.TablePrefix
|
||||
if tablePrefix != "" && strings.HasPrefix(data.TableName, tablePrefix) {
|
||||
@@ -340,7 +380,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
}
|
||||
|
||||
data.TableName = toCamel(data.TableName)
|
||||
data.TName = firstLetterToLow(data.TableName)
|
||||
data.TName = firstLetterToLower(data.TableName)
|
||||
|
||||
// find table comment
|
||||
for _, o := range stmt.Options {
|
||||
@@ -391,6 +431,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
}
|
||||
}
|
||||
if isPrimaryKey[colName] {
|
||||
field.IsPrimaryKey = true
|
||||
gormTag.WriteString(";primary_key")
|
||||
}
|
||||
isNotNull := false
|
||||
@@ -444,7 +485,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
|
||||
default: // gorm
|
||||
if !isPrimaryKey[colName] && isNotNull {
|
||||
gormTag.WriteString(";NOT NULL")
|
||||
gormTag.WriteString(";not null")
|
||||
}
|
||||
tags = append(tags, "gorm", gormTag.String())
|
||||
|
||||
@@ -473,24 +514,26 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
|
||||
data.Fields = append(data.Fields, field)
|
||||
}
|
||||
|
||||
if len(data.Fields) == 0 {
|
||||
return nil, errors.New("no columns found in table " + data.TableName)
|
||||
}
|
||||
|
||||
data.CrudInfo = newCrudInfo(data)
|
||||
data.CrudInfo.IsCommonType = data.isCommonStyle(opt.IsEmbed)
|
||||
|
||||
if v, ok := opt.FieldTypes[SubStructKey]; ok {
|
||||
data.SubStructs = v
|
||||
}
|
||||
if v, ok := opt.FieldTypes[ProtoSubStructKey]; ok {
|
||||
data.ProtoSubStructs = v
|
||||
}
|
||||
data.DBDriver = opt.DBDriver
|
||||
|
||||
updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handlerStructCode, err := getHandlerStructCodes(data, opt.JSONNamedType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelStructCode, importPaths, err := getModelStructCode(data, importPath, opt.IsEmbed, opt.JSONNamedType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -501,14 +544,35 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFileCode, err := getProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceStructCode, err := getServiceStructCode(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
handlerStructCode := ""
|
||||
serviceStructCode := ""
|
||||
protoFileCode := ""
|
||||
if data.isCommonStyle(opt.IsEmbed) {
|
||||
handlerStructCode, err = getCommonHandlerStructCodes(data, opt.JSONNamedType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serviceStructCode, err = getCommonServiceStructCode(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
protoFileCode, err = getCommonProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
handlerStructCode, err = getHandlerStructCodes(data, opt.JSONNamedType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serviceStructCode, err = getServiceStructCode(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
protoFileCode, err = getProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &codeText{
|
||||
@@ -519,6 +583,7 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
handlerStruct: handlerStructCode,
|
||||
protoFile: protoFileCode,
|
||||
serviceStruct: serviceStructCode,
|
||||
crudInfo: data.CrudInfo.getCode(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -586,6 +651,9 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
|
||||
// force conversion of ID field to uint64 type
|
||||
if field.Name == "ID" {
|
||||
data.Fields[i].GoType = "uint64"
|
||||
if data.isCommonStyle(isEmbed) {
|
||||
data.Fields[i].GoType = data.CrudInfo.GoType
|
||||
}
|
||||
}
|
||||
switch field.DBDriver {
|
||||
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
|
||||
@@ -763,7 +831,7 @@ func getModelJSONCode(data tmplData) (string, error) {
|
||||
}
|
||||
|
||||
func getProtoFileCode(data tmplData, jsonNamedType int, isWebProto bool, isExtendedAPI bool) (string, error) {
|
||||
data.Fields = goTypeToProto(data.Fields, jsonNamedType)
|
||||
data.Fields = goTypeToProto(data.Fields, jsonNamedType, false)
|
||||
|
||||
var err error
|
||||
builder := strings.Builder{}
|
||||
@@ -790,20 +858,20 @@ func getProtoFileCode(data tmplData, jsonNamedType int, isWebProto bool, isExten
|
||||
|
||||
protoMessageCreateCode, err := tmplExecuteWithFilter(data, protoMessageCreateTmpl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
|
||||
return "", fmt.Errorf("handle protoMessageCreateTmpl error: %v", err)
|
||||
}
|
||||
|
||||
protoMessageUpdateCode, err := tmplExecuteWithFilter(data, protoMessageUpdateTmpl, columnID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
|
||||
return "", fmt.Errorf("handle protoMessageUpdateTmpl error: %v", err)
|
||||
}
|
||||
if !isWebProto {
|
||||
protoMessageUpdateCode = strings.ReplaceAll(protoMessageUpdateCode, ` [(tagger.tags) = "uri:\"id\"" ]`, "")
|
||||
protoMessageUpdateCode = strings.ReplaceAll(protoMessageUpdateCode, `, (tagger.tags) = "uri:\"id\""`, "")
|
||||
}
|
||||
|
||||
protoMessageDetailCode, err := tmplExecuteWithFilter(data, protoMessageDetailTmpl, columnID, columnCreatedAt, columnUpdatedAt)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
|
||||
return "", fmt.Errorf("handle protoMessageDetailTmpl error: %v", err)
|
||||
}
|
||||
|
||||
code = strings.ReplaceAll(code, "// protoMessageCreateCode", protoMessageCreateCode)
|
||||
@@ -901,13 +969,13 @@ func getServiceStructCode(data tmplData) (string, error) {
|
||||
|
||||
serviceCreateStructCode, err := tmplExecuteWithFilter(data, serviceCreateStructTmpl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
|
||||
return "", fmt.Errorf("handle serviceCreateStructTmpl error: %v", err)
|
||||
}
|
||||
serviceCreateStructCode = strings.ReplaceAll(serviceCreateStructCode, "ID:", "Id:")
|
||||
|
||||
serviceUpdateStructCode, err := tmplExecuteWithFilter(data, serviceUpdateStructTmpl, columnID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
|
||||
return "", fmt.Errorf("handle serviceUpdateStructTmpl error: %v", err)
|
||||
}
|
||||
serviceUpdateStructCode = strings.ReplaceAll(serviceUpdateStructCode, "ID:", "Id:")
|
||||
|
||||
@@ -1017,7 +1085,7 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
|
||||
}
|
||||
|
||||
// nolint
|
||||
func goTypeToProto(fields []tmplField, jsonNameType int) []tmplField {
|
||||
func goTypeToProto(fields []tmplField, jsonNameType int, isCommonStyle bool) []tmplField {
|
||||
var newFields []tmplField
|
||||
for _, field := range fields {
|
||||
switch field.GoType {
|
||||
@@ -1053,7 +1121,7 @@ func goTypeToProto(fields []tmplField, jsonNameType int) []tmplField {
|
||||
field.GoType = "repeated string"
|
||||
}
|
||||
} else {
|
||||
if strings.ToLower(field.Name) == "id" {
|
||||
if strings.ToLower(field.Name) == "id" && !isCommonStyle {
|
||||
field.GoType = "uint64"
|
||||
}
|
||||
}
|
||||
@@ -1154,7 +1222,7 @@ func toCamel(s string) string {
|
||||
return str
|
||||
}
|
||||
|
||||
func firstLetterToLow(str string) string {
|
||||
func firstLetterToLower(str string) string {
|
||||
if len(str) == 0 {
|
||||
return str
|
||||
}
|
||||
@@ -1167,7 +1235,7 @@ func firstLetterToLow(str string) string {
|
||||
}
|
||||
|
||||
func customToCamel(str string) string {
|
||||
str = firstLetterToLow(toCamel(str))
|
||||
str = firstLetterToLower(toCamel(str))
|
||||
|
||||
if len(str) == 2 {
|
||||
if str == "iD" {
|
||||
|
||||
Reference in New Issue
Block a user