mirror of
https://github.com/zhufuyi/sponge.git
synced 2025-10-04 00:16:25 +08:00
fix parse sql error
This commit is contained in:
@@ -45,6 +45,9 @@ const (
|
||||
DBDriverSqlite = "sqlite"
|
||||
// DBDriverMongodb mongodb driver
|
||||
DBDriverMongodb = "mongodb"
|
||||
|
||||
jsonTypeName = "datatypes.JSON"
|
||||
jsonPkgPath = "gorm.io/datatypes"
|
||||
)
|
||||
|
||||
// Codes content
|
||||
@@ -147,6 +150,13 @@ type tmplField struct {
|
||||
Comment string
|
||||
JSONName string
|
||||
DBDriver string
|
||||
|
||||
rewriterField *rewriterField
|
||||
}
|
||||
|
||||
type rewriterField struct {
|
||||
goType string
|
||||
path string
|
||||
}
|
||||
|
||||
// ConditionZero type of condition 0, used in dao template code
|
||||
@@ -217,7 +227,7 @@ func (t tmplField) GoTypeZero() string {
|
||||
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32",
|
||||
"sql.NullInt32", "sql.NullInt64", "sql.NullFloat64":
|
||||
return `0`
|
||||
case "string", "sql.NullString":
|
||||
case "string", "sql.NullString", jsonTypeName:
|
||||
return `""`
|
||||
case "time.Time", "*time.Time", "sql.NullTime":
|
||||
return `0 /*time.Now().Second()*/`
|
||||
@@ -446,11 +456,17 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
if !canNull {
|
||||
nullStyle = NullDisable
|
||||
}
|
||||
goType, pkg := mysqlToGoType(col.Tp, nullStyle)
|
||||
goType, pkg, rrField := mysqlToGoType(col.Tp, nullStyle)
|
||||
if pkg != "" {
|
||||
importPath = append(importPath, pkg)
|
||||
}
|
||||
field.GoType = goType
|
||||
field.rewriterField = rrField
|
||||
if opt.DBDriver == DBDriverPostgresql {
|
||||
if opt.FieldTypes[colName] == "bool" {
|
||||
field.GoType = "bool" // rewritten type
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data.Fields = append(data.Fields, field)
|
||||
@@ -522,6 +538,15 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool) (stri
|
||||
if isIgnoreFields(field.ColName) {
|
||||
continue
|
||||
}
|
||||
switch field.DBDriver {
|
||||
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
|
||||
if field.rewriterField != nil {
|
||||
if field.rewriterField.goType == jsonTypeName {
|
||||
field.GoType = jsonTypeName
|
||||
importPaths = append(importPaths, jsonPkgPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
newFields = append(newFields, field)
|
||||
if strings.Contains(field.GoType, "time.Time") {
|
||||
isHaveTimeType = true
|
||||
@@ -559,6 +584,15 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool) (stri
|
||||
if field.Name == "ID" {
|
||||
data.Fields[i].GoType = "uint64"
|
||||
}
|
||||
switch field.DBDriver {
|
||||
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
|
||||
if field.rewriterField != nil {
|
||||
if field.rewriterField.goType == jsonTypeName {
|
||||
data.Fields[i].GoType = jsonTypeName
|
||||
importPaths = append(importPaths, jsonPkgPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
newImportPaths = importPaths
|
||||
@@ -618,6 +652,14 @@ func getUpdateFieldsCode(data tmplData, isEmbed bool) (string, error) {
|
||||
if isIgnoreFields(field.ColName, falseColumns...) || field.ColName == columnID || field.ColName == _columnID {
|
||||
continue
|
||||
}
|
||||
switch field.DBDriver {
|
||||
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
|
||||
if field.rewriterField != nil {
|
||||
if field.rewriterField.goType == jsonTypeName {
|
||||
field.GoType = "[]byte"
|
||||
}
|
||||
}
|
||||
}
|
||||
newFields = append(newFields, field)
|
||||
}
|
||||
data.Fields = newFields
|
||||
@@ -889,7 +931,7 @@ func addCommaToJSON(modelJSONCode string) string {
|
||||
}
|
||||
|
||||
// nolint
|
||||
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string) {
|
||||
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string, rrField *rewriterField) {
|
||||
if style == NullInSql {
|
||||
path = "database/sql"
|
||||
switch colTp.Tp {
|
||||
@@ -909,7 +951,7 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
|
||||
case mysql.TypeJSON, mysql.TypeEnum:
|
||||
name = "sql.NullString"
|
||||
default:
|
||||
return "UnSupport", ""
|
||||
return "UnSupport", "", nil
|
||||
}
|
||||
} else {
|
||||
switch colTp.Tp {
|
||||
@@ -935,16 +977,22 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
|
||||
name = "time.Time"
|
||||
case mysql.TypeDecimal, mysql.TypeNewDecimal:
|
||||
name = "string"
|
||||
case mysql.TypeJSON, mysql.TypeEnum:
|
||||
case mysql.TypeEnum:
|
||||
name = "string"
|
||||
case mysql.TypeJSON:
|
||||
name = "string"
|
||||
rrField = &rewriterField{
|
||||
goType: jsonTypeName,
|
||||
path: jsonPkgPath,
|
||||
}
|
||||
default:
|
||||
return "UnSupport", ""
|
||||
return "UnSupport", "", nil
|
||||
}
|
||||
if style == NullInPointer {
|
||||
name = "*" + name
|
||||
}
|
||||
}
|
||||
return name, path
|
||||
return name, path, rrField
|
||||
}
|
||||
|
||||
// nolint
|
||||
@@ -970,6 +1018,8 @@ func goTypeToProto(fields []tmplField) []tmplField {
|
||||
field.GoType = "string"
|
||||
case goTypeStrings:
|
||||
field.GoType = "repeated string"
|
||||
case jsonTypeName:
|
||||
field.GoType = "string"
|
||||
}
|
||||
|
||||
if field.DBDriver == DBDriverMongodb && field.GoType != "" {
|
||||
|
Reference in New Issue
Block a user