|
|
|
|
@@ -51,8 +51,14 @@ const (
|
|
|
|
|
// DBDriverMongodb mongodb driver
|
|
|
|
|
DBDriverMongodb = "mongodb"
|
|
|
|
|
|
|
|
|
|
jsonTypeName = "datatypes.JSON"
|
|
|
|
|
jsonPkgPath = "gorm.io/datatypes"
|
|
|
|
|
jsonTypeName = "datatypes.JSON"
|
|
|
|
|
jsonPkgPath = "gorm.io/datatypes"
|
|
|
|
|
boolTypeName = "sgorm.Bool"
|
|
|
|
|
boolPkgPath = "github.com/go-dev-frame/sponge/pkg/sgorm"
|
|
|
|
|
decimalTypeName = "decimal.Decimal"
|
|
|
|
|
decimalPkgPath = "github.com/shopspring/decimal"
|
|
|
|
|
|
|
|
|
|
unknownCustomType = "UnknownCustomType"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// Codes content
|
|
|
|
|
@@ -135,8 +141,8 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
|
|
|
|
|
CodeTypeProto: strings.Join(protoFileCodes, "\n\n"),
|
|
|
|
|
CodeTypeService: strings.Join(serviceStructCodes, "\n\n"),
|
|
|
|
|
TableName: strings.Join(tableNames, ", "),
|
|
|
|
|
CodeTypeCrudInfo: strings.Join(primaryKeysCodes, "||||"),
|
|
|
|
|
CodeTypeTableInfo: strings.Join(tableInfoCodes, "||||"),
|
|
|
|
|
CodeTypeCrudInfo: strings.Join(primaryKeysCodes, " |||| "),
|
|
|
|
|
CodeTypeTableInfo: strings.Join(tableInfoCodes, " |||| "),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return codesMap, nil
|
|
|
|
|
@@ -188,30 +194,46 @@ func (t tmplField) ConditionZero() string {
|
|
|
|
|
switch t.GoType {
|
|
|
|
|
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32", //nolint
|
|
|
|
|
"sql.NullInt32", "sql.NullInt64", "sql.NullFloat64": //nolint
|
|
|
|
|
return `!= 0`
|
|
|
|
|
return ` != 0`
|
|
|
|
|
case "string", "sql.NullString": //nolint
|
|
|
|
|
return `!= ""`
|
|
|
|
|
return ` != ""`
|
|
|
|
|
case "time.Time", "*time.Time", "sql.NullTime": //nolint
|
|
|
|
|
return `.IsZero() == false`
|
|
|
|
|
case "[]byte", "[]string", "[]int", "interface{}": //nolint
|
|
|
|
|
return `!= nil` //nolint
|
|
|
|
|
return ` != nil` //nolint
|
|
|
|
|
case "bool": //nolint
|
|
|
|
|
return `!= false /*Warning: if the value itself is false, can't be updated*/`
|
|
|
|
|
return ` != false`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if t.DBDriver == DBDriverMongodb {
|
|
|
|
|
switch t.DBDriver {
|
|
|
|
|
case DBDriverMysql, DBDriverPostgresql, DBDriverTidb:
|
|
|
|
|
if t.rewriterField != nil {
|
|
|
|
|
switch t.rewriterField.goType {
|
|
|
|
|
case boolTypeName:
|
|
|
|
|
return ` != nil` //nolint
|
|
|
|
|
case jsonTypeName:
|
|
|
|
|
return `.String() != ""`
|
|
|
|
|
case decimalTypeName:
|
|
|
|
|
return `.IsZero() == false`
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
case DBDriverMongodb:
|
|
|
|
|
if t.GoType == goTypeOID {
|
|
|
|
|
return `!= primitive.NilObjectID`
|
|
|
|
|
return ` != primitive.NilObjectID`
|
|
|
|
|
}
|
|
|
|
|
if t.GoType == "*"+t.Name {
|
|
|
|
|
return `!= nil`
|
|
|
|
|
return ` != nil` //nolint
|
|
|
|
|
}
|
|
|
|
|
if strings.Contains(t.GoType, "[]") {
|
|
|
|
|
return `!= nil`
|
|
|
|
|
return ` != nil` //nolint
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return `!= ` + t.GoType
|
|
|
|
|
if t.GoType == "" {
|
|
|
|
|
return ` != "unknown_zero_value"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ` != ` + t.GoType
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GoZero type of 0, used in model to json template code
|
|
|
|
|
@@ -230,7 +252,17 @@ func (t tmplField) GoZero() string {
|
|
|
|
|
return `= false`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if t.DBDriver == DBDriverMongodb {
|
|
|
|
|
switch t.DBDriver {
|
|
|
|
|
case DBDriverMysql, DBDriverPostgresql, DBDriverTidb:
|
|
|
|
|
if t.rewriterField != nil {
|
|
|
|
|
switch t.rewriterField.goType {
|
|
|
|
|
case jsonTypeName, decimalTypeName:
|
|
|
|
|
return ` = "string"`
|
|
|
|
|
case boolTypeName:
|
|
|
|
|
return `= false`
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
case DBDriverMongodb:
|
|
|
|
|
if t.GoType == goTypeOID {
|
|
|
|
|
return `= primitive.NilObjectID`
|
|
|
|
|
}
|
|
|
|
|
@@ -242,10 +274,14 @@ func (t tmplField) GoZero() string {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if t.GoType == "" {
|
|
|
|
|
return `!= "unknown_zero_value"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return `= ` + t.GoType
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GoTypeZero type of 0, used in service template code
|
|
|
|
|
// GoTypeZero type of 0, used in service template code, corresponding protobuf type
|
|
|
|
|
func (t tmplField) GoTypeZero() string {
|
|
|
|
|
switch t.GoType {
|
|
|
|
|
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32",
|
|
|
|
|
@@ -254,14 +290,26 @@ func (t tmplField) GoTypeZero() string {
|
|
|
|
|
case "string", "sql.NullString", jsonTypeName:
|
|
|
|
|
return `""`
|
|
|
|
|
case "time.Time", "*time.Time", "sql.NullTime":
|
|
|
|
|
return `0 /*time.Now().Second()*/`
|
|
|
|
|
return `""`
|
|
|
|
|
case "[]byte", "[]string", "[]int", "interface{}": //nolint
|
|
|
|
|
return `nil` //nolint
|
|
|
|
|
case "bool": //nolint
|
|
|
|
|
return `false`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if t.DBDriver == DBDriverMongodb {
|
|
|
|
|
switch t.DBDriver {
|
|
|
|
|
case DBDriverMysql, DBDriverPostgresql, DBDriverTidb:
|
|
|
|
|
if t.rewriterField != nil {
|
|
|
|
|
switch t.rewriterField.goType {
|
|
|
|
|
case jsonTypeName:
|
|
|
|
|
return `""` //nolint
|
|
|
|
|
case decimalTypeName:
|
|
|
|
|
return `""`
|
|
|
|
|
case boolTypeName:
|
|
|
|
|
return `false`
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
case DBDriverMongodb:
|
|
|
|
|
if t.GoType == goTypeOID {
|
|
|
|
|
return `primitive.NilObjectID`
|
|
|
|
|
}
|
|
|
|
|
@@ -273,6 +321,10 @@ func (t tmplField) GoTypeZero() string {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if t.GoType == "" {
|
|
|
|
|
return `"unknown_zero_value"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return t.GoType
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -547,12 +599,12 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
|
|
|
|
return &codeText{tableInfo: tableInfo.getCode()}, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
|
|
|
|
|
modelStructCode, importPaths, err := getModelStructCode(data, importPath, opt.IsEmbed, opt.JSONNamedType)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
modelStructCode, importPaths, err := getModelStructCode(data, importPath, opt.IsEmbed, opt.JSONNamedType)
|
|
|
|
|
updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
@@ -627,9 +679,13 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
|
|
|
|
|
switch field.DBDriver {
|
|
|
|
|
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
|
|
|
|
|
if field.rewriterField != nil {
|
|
|
|
|
if field.rewriterField.goType == jsonTypeName {
|
|
|
|
|
field.GoType = jsonTypeName
|
|
|
|
|
importPaths = append(importPaths, jsonPkgPath)
|
|
|
|
|
switch field.rewriterField.goType {
|
|
|
|
|
//case jsonTypeName, decimalTypeName:
|
|
|
|
|
// field.GoType = field.rewriterField.goType
|
|
|
|
|
// importPaths = append(importPaths, field.rewriterField.path)
|
|
|
|
|
case jsonTypeName, decimalTypeName, boolTypeName:
|
|
|
|
|
field.GoType = "*" + field.rewriterField.goType
|
|
|
|
|
importPaths = append(importPaths, field.rewriterField.path)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -653,37 +709,41 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
|
|
|
|
|
}
|
|
|
|
|
newImportPaths = append(newImportPaths, "github.com/go-dev-frame/sponge/pkg/sgorm")
|
|
|
|
|
} else {
|
|
|
|
|
for i, field := range data.Fields {
|
|
|
|
|
for _, field := range data.Fields {
|
|
|
|
|
switch field.DBDriver {
|
|
|
|
|
case DBDriverMongodb:
|
|
|
|
|
if field.Name == "ID" {
|
|
|
|
|
data.Fields[i].GoType = goTypeOID
|
|
|
|
|
field.GoType = goTypeOID
|
|
|
|
|
importPaths = append(importPaths, "go.mongodb.org/mongo-driver/bson/primitive")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
if strings.Contains(field.GoType, "time.Time") {
|
|
|
|
|
data.Fields[i].GoType = "*time.Time"
|
|
|
|
|
continue
|
|
|
|
|
field.GoType = "*time.Time"
|
|
|
|
|
}
|
|
|
|
|
// force conversion of ID field to uint64 type
|
|
|
|
|
if field.Name == "ID" {
|
|
|
|
|
data.Fields[i].GoType = "uint64"
|
|
|
|
|
field.GoType = "uint64"
|
|
|
|
|
if data.isCommonStyle(isEmbed) {
|
|
|
|
|
data.Fields[i].GoType = data.CrudInfo.GoType
|
|
|
|
|
field.GoType = data.CrudInfo.GoType
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
switch field.DBDriver {
|
|
|
|
|
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
|
|
|
|
|
if field.DBDriver == DBDriverMysql || field.DBDriver == DBDriverPostgresql || field.DBDriver == DBDriverTidb {
|
|
|
|
|
if field.rewriterField != nil {
|
|
|
|
|
if field.rewriterField.goType == jsonTypeName {
|
|
|
|
|
data.Fields[i].GoType = jsonTypeName
|
|
|
|
|
importPaths = append(importPaths, jsonPkgPath)
|
|
|
|
|
switch field.rewriterField.goType {
|
|
|
|
|
//case jsonTypeName, decimalTypeName:
|
|
|
|
|
// field.GoType = field.rewriterField.goType
|
|
|
|
|
// importPaths = append(importPaths, field.rewriterField.path)
|
|
|
|
|
case jsonTypeName, decimalTypeName, boolTypeName:
|
|
|
|
|
field.GoType = "*" + field.rewriterField.goType
|
|
|
|
|
importPaths = append(importPaths, field.rewriterField.path)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
newFields = append(newFields, field)
|
|
|
|
|
}
|
|
|
|
|
data.Fields = newFields
|
|
|
|
|
newImportPaths = importPaths
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -729,7 +789,7 @@ func getModelCode(data modelCodes) (string, error) {
|
|
|
|
|
|
|
|
|
|
code, err := format.Source([]byte(builder.String()))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", fmt.Errorf("format.Source error: %v", err)
|
|
|
|
|
return "", fmt.Errorf("getModelCode format.Source error: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return string(code), nil
|
|
|
|
|
@@ -784,6 +844,7 @@ func getHandlerStructCodes(data tmplData, jsonNamedType int) (string, error) {
|
|
|
|
|
} else {
|
|
|
|
|
field.JSONName = customToCamel(field.ColName) // camel case (default)
|
|
|
|
|
}
|
|
|
|
|
field.GoType = getHandlerGoType(&field)
|
|
|
|
|
newFields = append(newFields, field)
|
|
|
|
|
}
|
|
|
|
|
data.Fields = newFields
|
|
|
|
|
@@ -839,7 +900,7 @@ func getModelJSONCode(data tmplData) (string, error) {
|
|
|
|
|
|
|
|
|
|
code, err := format.Source([]byte(builder.String()))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", fmt.Errorf("format.Source error: %v", err)
|
|
|
|
|
return "", fmt.Errorf("getModelJSONCode format.Source error: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
modelJSONCode := strings.ReplaceAll(string(code), " =", ":")
|
|
|
|
|
@@ -1042,33 +1103,51 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
|
|
|
|
|
if style == NullInSql {
|
|
|
|
|
path = "database/sql"
|
|
|
|
|
switch colTp.Tp {
|
|
|
|
|
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong:
|
|
|
|
|
case mysql.TypeTiny:
|
|
|
|
|
name = "sql.NullInt8"
|
|
|
|
|
case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeYear:
|
|
|
|
|
name = "sql.NullInt32"
|
|
|
|
|
case mysql.TypeLonglong:
|
|
|
|
|
case mysql.TypeLonglong, mysql.TypeDuration:
|
|
|
|
|
name = "sql.NullInt64"
|
|
|
|
|
case mysql.TypeFloat, mysql.TypeDouble:
|
|
|
|
|
name = "sql.NullFloat64"
|
|
|
|
|
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString,
|
|
|
|
|
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
|
|
|
|
|
name = "sql.NullString"
|
|
|
|
|
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
|
|
|
|
|
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate, mysql.TypeNewDate:
|
|
|
|
|
name = "sql.NullTime"
|
|
|
|
|
case mysql.TypeDecimal, mysql.TypeNewDecimal:
|
|
|
|
|
name = "sql.NullString"
|
|
|
|
|
case mysql.TypeJSON, mysql.TypeEnum:
|
|
|
|
|
case mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry:
|
|
|
|
|
name = "sql.NullString"
|
|
|
|
|
case mysql.TypeBit:
|
|
|
|
|
name = "sql.NullBool"
|
|
|
|
|
default:
|
|
|
|
|
return "UnSupport", "", nil
|
|
|
|
|
return unknownCustomType, "", nil
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
switch colTp.Tp {
|
|
|
|
|
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong:
|
|
|
|
|
case mysql.TypeTiny:
|
|
|
|
|
if strings.ToLower(colTp.String()) == "tinyint(1)" {
|
|
|
|
|
name = "bool"
|
|
|
|
|
rrField = &rewriterField{
|
|
|
|
|
goType: boolTypeName,
|
|
|
|
|
path: boolPkgPath,
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if mysql.HasUnsignedFlag(colTp.Flag) {
|
|
|
|
|
name = "uint"
|
|
|
|
|
} else {
|
|
|
|
|
name = "int"
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeYear:
|
|
|
|
|
if mysql.HasUnsignedFlag(colTp.Flag) {
|
|
|
|
|
name = "uint"
|
|
|
|
|
} else {
|
|
|
|
|
name = "int"
|
|
|
|
|
}
|
|
|
|
|
case mysql.TypeLonglong:
|
|
|
|
|
case mysql.TypeLonglong, mysql.TypeDuration:
|
|
|
|
|
if mysql.HasUnsignedFlag(colTp.Flag) {
|
|
|
|
|
name = "uint64"
|
|
|
|
|
} else {
|
|
|
|
|
@@ -1079,12 +1158,10 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
|
|
|
|
|
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString,
|
|
|
|
|
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
|
|
|
|
|
name = "string"
|
|
|
|
|
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
|
|
|
|
|
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate, mysql.TypeNewDate:
|
|
|
|
|
path = "time" //nolint
|
|
|
|
|
name = "time.Time"
|
|
|
|
|
case mysql.TypeDecimal, mysql.TypeNewDecimal:
|
|
|
|
|
name = "string"
|
|
|
|
|
case mysql.TypeEnum:
|
|
|
|
|
case mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry:
|
|
|
|
|
name = "string"
|
|
|
|
|
case mysql.TypeJSON:
|
|
|
|
|
name = "string"
|
|
|
|
|
@@ -1092,8 +1169,24 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
|
|
|
|
|
goType: jsonTypeName,
|
|
|
|
|
path: jsonPkgPath,
|
|
|
|
|
}
|
|
|
|
|
case mysql.TypeBit:
|
|
|
|
|
if strings.ToLower(colTp.String()) == "bit(1)" {
|
|
|
|
|
name = "bool"
|
|
|
|
|
rrField = &rewriterField{
|
|
|
|
|
goType: boolTypeName,
|
|
|
|
|
path: boolPkgPath,
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
name = "[]byte"
|
|
|
|
|
}
|
|
|
|
|
case mysql.TypeDecimal, mysql.TypeNewDecimal:
|
|
|
|
|
name = "string"
|
|
|
|
|
rrField = &rewriterField{
|
|
|
|
|
goType: decimalTypeName,
|
|
|
|
|
path: decimalPkgPath,
|
|
|
|
|
}
|
|
|
|
|
default:
|
|
|
|
|
return "UnSupport", "", nil
|
|
|
|
|
return unknownCustomType, "", nil
|
|
|
|
|
}
|
|
|
|
|
if style == NullInPointer {
|
|
|
|
|
name = "*" + name
|
|
|
|
|
@@ -1150,6 +1243,15 @@ func goTypeToProto(fields []tmplField, jsonNameType int, isCommonStyle bool) []t
|
|
|
|
|
field.JSONName = customToCamel(field.ColName) // camel case (default)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if field.rewriterField != nil {
|
|
|
|
|
switch field.rewriterField.goType {
|
|
|
|
|
case jsonTypeName, decimalTypeName:
|
|
|
|
|
field.GoType = "string"
|
|
|
|
|
case boolTypeName:
|
|
|
|
|
field.GoType = "bool"
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
newFields = append(newFields, field)
|
|
|
|
|
}
|
|
|
|
|
return newFields
|
|
|
|
|
|