feat: support parsing mysql data types bit(1) and decimal

This commit is contained in:
zhuyasen
2025-01-19 21:55:27 +08:00
parent 2e313280ff
commit 44493211a5
6 changed files with 197 additions and 99 deletions

View File

@@ -168,6 +168,7 @@ func getCommonHandlerStructCodes(data tmplData, jsonNamedType int) (string, erro
} else {
field.JSONName = customToCamel(field.ColName) // camel case (default)
}
field.GoType = getHandlerGoType(&field)
newFields = append(newFields, field)
}
data.Fields = newFields
@@ -388,3 +389,23 @@ func customEndOfLetterToLower(srcStr string, str string) string {
return str
}
func getHandlerGoType(field *tmplField) string {
var goType = field.GoType
if field.DBDriver == DBDriverMysql || field.DBDriver == DBDriverPostgresql || field.DBDriver == DBDriverTidb {
if field.rewriterField != nil {
switch field.rewriterField.goType {
case jsonTypeName:
goType = "string"
case boolTypeName:
goType = "*bool"
case decimalTypeName:
goType = "string"
}
}
}
if field.GoType == "time.Time" {
goType = "*time.Time"
}
return goType
}

View File

@@ -289,7 +289,7 @@ func convertMongoToMysqlType(goType string) string {
case goTypeTime:
return "timestamp" //nolint
case goTypeBool:
return "tinyint(1)"
return "bit(1)"
case goTypeOID, goTypeNil, goTypeBytes, goTypeInterface, goTypeSliceInterface, goTypeInts, goTypeStrings:
return "json"
}
@@ -314,7 +314,7 @@ func convertToProtoFieldType(name string, goType string) string {
case "[]int32":
return "repeated int32"
case "[]byte":
return "string"
return "bytes"
case goTypeStrings:
return "repeated string"
}

View File

@@ -51,8 +51,14 @@ const (
// DBDriverMongodb mongodb driver
DBDriverMongodb = "mongodb"
jsonTypeName = "datatypes.JSON"
jsonPkgPath = "gorm.io/datatypes"
jsonTypeName = "datatypes.JSON"
jsonPkgPath = "gorm.io/datatypes"
boolTypeName = "sgorm.Bool"
boolPkgPath = "github.com/go-dev-frame/sponge/pkg/sgorm"
decimalTypeName = "decimal.Decimal"
decimalPkgPath = "github.com/shopspring/decimal"
unknownCustomType = "UnknownCustomType"
)
// Codes content
@@ -135,8 +141,8 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
CodeTypeProto: strings.Join(protoFileCodes, "\n\n"),
CodeTypeService: strings.Join(serviceStructCodes, "\n\n"),
TableName: strings.Join(tableNames, ", "),
CodeTypeCrudInfo: strings.Join(primaryKeysCodes, "||||"),
CodeTypeTableInfo: strings.Join(tableInfoCodes, "||||"),
CodeTypeCrudInfo: strings.Join(primaryKeysCodes, " |||| "),
CodeTypeTableInfo: strings.Join(tableInfoCodes, " |||| "),
}
return codesMap, nil
@@ -188,30 +194,46 @@ func (t tmplField) ConditionZero() string {
switch t.GoType {
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32", //nolint
"sql.NullInt32", "sql.NullInt64", "sql.NullFloat64": //nolint
return `!= 0`
return ` != 0`
case "string", "sql.NullString": //nolint
return `!= ""`
return ` != ""`
case "time.Time", "*time.Time", "sql.NullTime": //nolint
return `.IsZero() == false`
case "[]byte", "[]string", "[]int", "interface{}": //nolint
return `!= nil` //nolint
return ` != nil` //nolint
case "bool": //nolint
return `!= false /*Warning: if the value itself is false, can't be updated*/`
return ` != false`
}
if t.DBDriver == DBDriverMongodb {
switch t.DBDriver {
case DBDriverMysql, DBDriverPostgresql, DBDriverTidb:
if t.rewriterField != nil {
switch t.rewriterField.goType {
case boolTypeName:
return ` != nil` //nolint
case jsonTypeName:
return `.String() != ""`
case decimalTypeName:
return `.IsZero() == false`
}
}
case DBDriverMongodb:
if t.GoType == goTypeOID {
return `!= primitive.NilObjectID`
return ` != primitive.NilObjectID`
}
if t.GoType == "*"+t.Name {
return `!= nil`
return ` != nil` //nolint
}
if strings.Contains(t.GoType, "[]") {
return `!= nil`
return ` != nil` //nolint
}
}
return `!= ` + t.GoType
if t.GoType == "" {
return ` != "unknown_zero_value"`
}
return ` != ` + t.GoType
}
// GoZero type of 0, used in model to json template code
@@ -230,7 +252,17 @@ func (t tmplField) GoZero() string {
return `= false`
}
if t.DBDriver == DBDriverMongodb {
switch t.DBDriver {
case DBDriverMysql, DBDriverPostgresql, DBDriverTidb:
if t.rewriterField != nil {
switch t.rewriterField.goType {
case jsonTypeName, decimalTypeName:
return ` = "string"`
case boolTypeName:
return `= false`
}
}
case DBDriverMongodb:
if t.GoType == goTypeOID {
return `= primitive.NilObjectID`
}
@@ -242,10 +274,14 @@ func (t tmplField) GoZero() string {
}
}
if t.GoType == "" {
return `!= "unknown_zero_value"`
}
return `= ` + t.GoType
}
// GoTypeZero type of 0, used in service template code
// GoTypeZero type of 0, used in service template code, corresponding protobuf type
func (t tmplField) GoTypeZero() string {
switch t.GoType {
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32",
@@ -254,14 +290,26 @@ func (t tmplField) GoTypeZero() string {
case "string", "sql.NullString", jsonTypeName:
return `""`
case "time.Time", "*time.Time", "sql.NullTime":
return `0 /*time.Now().Second()*/`
return `""`
case "[]byte", "[]string", "[]int", "interface{}": //nolint
return `nil` //nolint
case "bool": //nolint
return `false`
}
if t.DBDriver == DBDriverMongodb {
switch t.DBDriver {
case DBDriverMysql, DBDriverPostgresql, DBDriverTidb:
if t.rewriterField != nil {
switch t.rewriterField.goType {
case jsonTypeName:
return `""` //nolint
case decimalTypeName:
return `""`
case boolTypeName:
return `false`
}
}
case DBDriverMongodb:
if t.GoType == goTypeOID {
return `primitive.NilObjectID`
}
@@ -273,6 +321,10 @@ func (t tmplField) GoTypeZero() string {
}
}
if t.GoType == "" {
return `"unknown_zero_value"`
}
return t.GoType
}
@@ -547,12 +599,12 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
return &codeText{tableInfo: tableInfo.getCode()}, nil
}
updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
modelStructCode, importPaths, err := getModelStructCode(data, importPath, opt.IsEmbed, opt.JSONNamedType)
if err != nil {
return nil, err
}
modelStructCode, importPaths, err := getModelStructCode(data, importPath, opt.IsEmbed, opt.JSONNamedType)
updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
if err != nil {
return nil, err
}
@@ -627,9 +679,13 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
switch field.DBDriver {
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
if field.rewriterField != nil {
if field.rewriterField.goType == jsonTypeName {
field.GoType = jsonTypeName
importPaths = append(importPaths, jsonPkgPath)
switch field.rewriterField.goType {
//case jsonTypeName, decimalTypeName:
// field.GoType = field.rewriterField.goType
// importPaths = append(importPaths, field.rewriterField.path)
case jsonTypeName, decimalTypeName, boolTypeName:
field.GoType = "*" + field.rewriterField.goType
importPaths = append(importPaths, field.rewriterField.path)
}
}
}
@@ -653,37 +709,41 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
}
newImportPaths = append(newImportPaths, "github.com/go-dev-frame/sponge/pkg/sgorm")
} else {
for i, field := range data.Fields {
for _, field := range data.Fields {
switch field.DBDriver {
case DBDriverMongodb:
if field.Name == "ID" {
data.Fields[i].GoType = goTypeOID
field.GoType = goTypeOID
importPaths = append(importPaths, "go.mongodb.org/mongo-driver/bson/primitive")
}
default:
if strings.Contains(field.GoType, "time.Time") {
data.Fields[i].GoType = "*time.Time"
continue
field.GoType = "*time.Time"
}
// force conversion of ID field to uint64 type
if field.Name == "ID" {
data.Fields[i].GoType = "uint64"
field.GoType = "uint64"
if data.isCommonStyle(isEmbed) {
data.Fields[i].GoType = data.CrudInfo.GoType
field.GoType = data.CrudInfo.GoType
}
}
switch field.DBDriver {
case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
if field.DBDriver == DBDriverMysql || field.DBDriver == DBDriverPostgresql || field.DBDriver == DBDriverTidb {
if field.rewriterField != nil {
if field.rewriterField.goType == jsonTypeName {
data.Fields[i].GoType = jsonTypeName
importPaths = append(importPaths, jsonPkgPath)
switch field.rewriterField.goType {
//case jsonTypeName, decimalTypeName:
// field.GoType = field.rewriterField.goType
// importPaths = append(importPaths, field.rewriterField.path)
case jsonTypeName, decimalTypeName, boolTypeName:
field.GoType = "*" + field.rewriterField.goType
importPaths = append(importPaths, field.rewriterField.path)
}
}
}
}
newFields = append(newFields, field)
}
data.Fields = newFields
newImportPaths = importPaths
}
@@ -729,7 +789,7 @@ func getModelCode(data modelCodes) (string, error) {
code, err := format.Source([]byte(builder.String()))
if err != nil {
return "", fmt.Errorf("format.Source error: %v", err)
return "", fmt.Errorf("getModelCode format.Source error: %v", err)
}
return string(code), nil
@@ -784,6 +844,7 @@ func getHandlerStructCodes(data tmplData, jsonNamedType int) (string, error) {
} else {
field.JSONName = customToCamel(field.ColName) // camel case (default)
}
field.GoType = getHandlerGoType(&field)
newFields = append(newFields, field)
}
data.Fields = newFields
@@ -839,7 +900,7 @@ func getModelJSONCode(data tmplData) (string, error) {
code, err := format.Source([]byte(builder.String()))
if err != nil {
return "", fmt.Errorf("format.Source error: %v", err)
return "", fmt.Errorf("getModelJSONCode format.Source error: %v", err)
}
modelJSONCode := strings.ReplaceAll(string(code), " =", ":")
@@ -1042,33 +1103,51 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
if style == NullInSql {
path = "database/sql"
switch colTp.Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong:
case mysql.TypeTiny:
name = "sql.NullInt8"
case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeYear:
name = "sql.NullInt32"
case mysql.TypeLonglong:
case mysql.TypeLonglong, mysql.TypeDuration:
name = "sql.NullInt64"
case mysql.TypeFloat, mysql.TypeDouble:
name = "sql.NullFloat64"
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString,
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
name = "sql.NullString"
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate, mysql.TypeNewDate:
name = "sql.NullTime"
case mysql.TypeDecimal, mysql.TypeNewDecimal:
name = "sql.NullString"
case mysql.TypeJSON, mysql.TypeEnum:
case mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry:
name = "sql.NullString"
case mysql.TypeBit:
name = "sql.NullBool"
default:
return "UnSupport", "", nil
return unknownCustomType, "", nil
}
} else {
switch colTp.Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong:
case mysql.TypeTiny:
if strings.ToLower(colTp.String()) == "tinyint(1)" {
name = "bool"
rrField = &rewriterField{
goType: boolTypeName,
path: boolPkgPath,
}
} else {
if mysql.HasUnsignedFlag(colTp.Flag) {
name = "uint"
} else {
name = "int"
}
}
case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeYear:
if mysql.HasUnsignedFlag(colTp.Flag) {
name = "uint"
} else {
name = "int"
}
case mysql.TypeLonglong:
case mysql.TypeLonglong, mysql.TypeDuration:
if mysql.HasUnsignedFlag(colTp.Flag) {
name = "uint64"
} else {
@@ -1079,12 +1158,10 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString,
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
name = "string"
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate, mysql.TypeNewDate:
path = "time" //nolint
name = "time.Time"
case mysql.TypeDecimal, mysql.TypeNewDecimal:
name = "string"
case mysql.TypeEnum:
case mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry:
name = "string"
case mysql.TypeJSON:
name = "string"
@@ -1092,8 +1169,24 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
goType: jsonTypeName,
path: jsonPkgPath,
}
case mysql.TypeBit:
if strings.ToLower(colTp.String()) == "bit(1)" {
name = "bool"
rrField = &rewriterField{
goType: boolTypeName,
path: boolPkgPath,
}
} else {
name = "[]byte"
}
case mysql.TypeDecimal, mysql.TypeNewDecimal:
name = "string"
rrField = &rewriterField{
goType: decimalTypeName,
path: decimalPkgPath,
}
default:
return "UnSupport", "", nil
return unknownCustomType, "", nil
}
if style == NullInPointer {
name = "*" + name
@@ -1150,6 +1243,15 @@ func goTypeToProto(fields []tmplField, jsonNameType int, isCommonStyle bool) []t
field.JSONName = customToCamel(field.ColName) // camel case (default)
}
if field.rewriterField != nil {
switch field.rewriterField.goType {
case jsonTypeName, decimalTypeName:
field.GoType = "string"
case boolTypeName:
field.GoType = "bool"
}
}
newFields = append(newFields, field)
}
return newFields

View File

@@ -10,6 +10,24 @@ import (
"github.com/zhufuyi/sqlparser/dependency/types"
)
func TestParseMysqlSQL(t *testing.T) {
sql := `CREATE TABLE orders (
order_id bigint NOT NULL AUTO_INCREMENT COMMENT 'order id',
user_id bigint NOT NULL COMMENT 'user id',
total_amount decimal(10,2) NOT NULL COMMENT 'total amount',
order_remark json NOT NULL COMMENT 'order remark',
order_status ENUM('Pending Payment', 'Paid', 'Shipped', 'Completed', 'Cancelled') NOT NULL DEFAULT 'active' COMMENT 'order status',
pay_type SET('Alipay', 'WeChat Pay', 'Bank Card') NOT NULL DEFAULT '' COMMENT 'pay type',
is_deleted bit(1) NOT NULL DEFAULT b'0' COMMENT '0-no1-yes',
created_time datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='orders table';`
codes, err := ParseSQL(sql, WithJSONTag(1), WithDBDriver(DBDriverMysql))
assert.Nil(t, err)
assert.NotEmpty(t, codes)
//printCode(codes)
}
func TestParseSQL(t *testing.T) {
sqls := []string{`create table user (
id bigint unsigned auto_increment,
@@ -292,6 +310,7 @@ func Test_mysqlToGoType(t *testing.T) {
{Tp: mysql.TypeTimestamp},
{Tp: mysql.TypeDecimal},
{Tp: mysql.TypeJSON},
{Tp: mysql.TypeBit},
}
var names []string
for _, d := range fields {

View File

@@ -65,14 +65,12 @@ func (field *PGField) getMysqlType() string {
return "int"
case "bigint", "bigserial", "int8":
return "bigint"
case "real":
case "real", "float4":
return "float"
case "decimal", "numeric", "float4", "float8":
return "decimal"
case "double precision":
case "double precision", "float8":
return "double"
case "money":
return "varchar(30)"
case "decimal", "numeric", "money":
return "decimal(10, 2)"
case "character", "character varying", "varchar", "char", "bpchar":
if field.Lengthvar > 4 {
return fmt.Sprintf("varchar(%d)", field.Lengthvar-4)
@@ -92,9 +90,9 @@ func (field *PGField) getMysqlType() string {
case "json", "jsonb":
return "json"
case "boolean", "bool":
return "bool"
return "bit(1)"
case "bit":
return "tinyint(1)"
return "bit"
}
// unknown type convert to varchar
@@ -113,27 +111,6 @@ func (fields PGFields) getPrimaryField() *PGField {
return f
}
}
/*
// if no primary key, find the first xxx_id field
if f == nil {
for _, field := range fields {
if strings.HasSuffix(field.Name, "_id") {
f = field
f.IsPrimaryKey = true
return f
}
}
}
// if no xxx_id field, find the first field
if f == nil {
for _, field := range fields {
f = field
f.IsPrimaryKey = true
return f
}
}
*/
return f
}

View File

@@ -68,27 +68,6 @@ func (fields SqliteFields) getPrimaryField() *SqliteField {
return f
}
}
/*
// if no primary key, find the first xxx_id field
if f == nil {
for _, field := range fields {
if strings.HasSuffix(field.Name, "_id") {
f = field
f.Pk = 1
return f
}
}
}
// if no xxx_id field, find the first field
if f == nil {
for _, field := range fields {
f = field
f.Pk = 1
return f
}
}
*/
return f
}