fix parse sql error

This commit is contained in:
zhuyasen
2024-06-16 11:58:28 +08:00
parent 5791c487e0
commit eef4b3d20a
4 changed files with 91 additions and 23 deletions

View File

@@ -45,6 +45,9 @@ const (
DBDriverSqlite = "sqlite"
// DBDriverMongodb mongodb driver
DBDriverMongodb = "mongodb"
jsonTypeName = "datatypes.JSON"
jsonPkgPath = "gorm.io/datatypes"
)
// Codes content
@@ -147,6 +150,13 @@ type tmplField struct {
Comment string
JSONName string
DBDriver string
rewriterField *rewriterField
}
type rewriterField struct {
goType string
path string
}
// ConditionZero type of condition 0, used in dao template code
@@ -217,7 +227,7 @@ func (t tmplField) GoTypeZero() string {
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32",
"sql.NullInt32", "sql.NullInt64", "sql.NullFloat64":
return `0`
case "string", "sql.NullString":
case "string", "sql.NullString", jsonTypeName:
return `""`
case "time.Time", "*time.Time", "sql.NullTime":
return `0 /*time.Now().Second()*/`
@@ -446,11 +456,17 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
if !canNull {
nullStyle = NullDisable
}
goType, pkg := mysqlToGoType(col.Tp, nullStyle)
goType, pkg, rrField := mysqlToGoType(col.Tp, nullStyle)
if pkg != "" {
importPath = append(importPath, pkg)
}
field.GoType = goType
field.rewriterField = rrField
if opt.DBDriver == DBDriverPostgresql {
if opt.FieldTypes[colName] == "bool" {
field.GoType = "bool" // rewritten type
}
}
}
data.Fields = append(data.Fields, field)
@@ -522,6 +538,15 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool) (stri
if isIgnoreFields(field.ColName) {
continue
}
switch field.DBDriver {
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
if field.rewriterField != nil {
if field.rewriterField.goType == jsonTypeName {
field.GoType = jsonTypeName
importPaths = append(importPaths, jsonPkgPath)
}
}
}
newFields = append(newFields, field)
if strings.Contains(field.GoType, "time.Time") {
isHaveTimeType = true
@@ -559,6 +584,15 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool) (stri
if field.Name == "ID" {
data.Fields[i].GoType = "uint64"
}
switch field.DBDriver {
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
if field.rewriterField != nil {
if field.rewriterField.goType == jsonTypeName {
data.Fields[i].GoType = jsonTypeName
importPaths = append(importPaths, jsonPkgPath)
}
}
}
}
}
newImportPaths = importPaths
@@ -618,6 +652,14 @@ func getUpdateFieldsCode(data tmplData, isEmbed bool) (string, error) {
if isIgnoreFields(field.ColName, falseColumns...) || field.ColName == columnID || field.ColName == _columnID {
continue
}
switch field.DBDriver {
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
if field.rewriterField != nil {
if field.rewriterField.goType == jsonTypeName {
field.GoType = "[]byte"
}
}
}
newFields = append(newFields, field)
}
data.Fields = newFields
@@ -889,7 +931,7 @@ func addCommaToJSON(modelJSONCode string) string {
}
// nolint
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string) {
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string, rrField *rewriterField) {
if style == NullInSql {
path = "database/sql"
switch colTp.Tp {
@@ -909,7 +951,7 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
case mysql.TypeJSON, mysql.TypeEnum:
name = "sql.NullString"
default:
return "UnSupport", ""
return "UnSupport", "", nil
}
} else {
switch colTp.Tp {
@@ -935,16 +977,22 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
name = "time.Time"
case mysql.TypeDecimal, mysql.TypeNewDecimal:
name = "string"
case mysql.TypeJSON, mysql.TypeEnum:
case mysql.TypeEnum:
name = "string"
case mysql.TypeJSON:
name = "string"
rrField = &rewriterField{
goType: jsonTypeName,
path: jsonPkgPath,
}
default:
return "UnSupport", ""
return "UnSupport", "", nil
}
if style == NullInPointer {
name = "*" + name
}
}
return name, path
return name, path, rrField
}
// nolint
@@ -970,6 +1018,8 @@ func goTypeToProto(fields []tmplField) []tmplField {
field.GoType = "string"
case goTypeStrings:
field.GoType = "repeated string"
case jsonTypeName:
field.GoType = "string"
}
if field.DBDriver == DBDriverMongodb && field.GoType != "" {