feat: custom Bool type compatible with postgresql bool

This commit is contained in:
zhuyasen
2025-03-25 16:41:54 +08:00
parent 4fabf6f3a9
commit 5ccb334c21
5 changed files with 105 additions and 87 deletions

View File

@@ -51,12 +51,13 @@ const (
// DBDriverMongodb mongodb driver
DBDriverMongodb = "mongodb"
jsonTypeName = "datatypes.JSON"
jsonPkgPath = "gorm.io/datatypes"
boolTypeName = "sgorm.Bool"
boolPkgPath = "github.com/go-dev-frame/sponge/pkg/sgorm"
decimalTypeName = "decimal.Decimal"
decimalPkgPath = "github.com/shopspring/decimal"
jsonTypeName = "datatypes.JSON"
jsonPkgPath = "gorm.io/datatypes"
boolTypeName = "sgorm.Bool"
boolTypeTinyName = "sgorm.TinyBool"
boolPkgPath = "github.com/go-dev-frame/sponge/pkg/sgorm"
decimalTypeName = "decimal.Decimal"
decimalPkgPath = "github.com/shopspring/decimal"
unknownCustomType = "UnknownCustomType"
)
@@ -194,7 +195,7 @@ func (t tmplField) ConditionZero() string {
if t.DBDriver == DBDriverMysql || t.DBDriver == DBDriverPostgresql || t.DBDriver == DBDriverTidb {
if t.rewriterField != nil {
switch t.rewriterField.goType {
case boolTypeName:
case boolTypeName, boolTypeTinyName:
return ` != nil` //nolint
case jsonTypeName:
return `.String() != ""`
@@ -244,7 +245,7 @@ func (t tmplField) GoZero() string {
switch t.rewriterField.goType {
case jsonTypeName, decimalTypeName:
return ` = "string"`
case boolTypeName:
case boolTypeName, boolTypeTinyName:
return `= false`
}
}
@@ -292,7 +293,7 @@ func (t tmplField) GoTypeZero() string {
return `""` //nolint
case decimalTypeName:
return `""`
case boolTypeName:
case boolTypeName, boolTypeTinyName:
return `false`
}
}
@@ -686,7 +687,7 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
//case jsonTypeName, decimalTypeName:
// field.GoType = field.rewriterField.goType
// importPaths = append(importPaths, field.rewriterField.path)
case jsonTypeName, decimalTypeName, boolTypeName:
case jsonTypeName, decimalTypeName, boolTypeName, boolTypeTinyName:
field.GoType = "*" + field.rewriterField.goType
importPaths = append(importPaths, field.rewriterField.path)
}
@@ -737,7 +738,7 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
//case jsonTypeName, decimalTypeName:
// field.GoType = field.rewriterField.goType
// importPaths = append(importPaths, field.rewriterField.path)
case jsonTypeName, decimalTypeName, boolTypeName:
case jsonTypeName, decimalTypeName, boolTypeName, boolTypeTinyName:
field.GoType = "*" + field.rewriterField.goType
importPaths = append(importPaths, field.rewriterField.path)
}
@@ -1134,7 +1135,7 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
if strings.ToLower(colTp.String()) == "tinyint(1)" {
name = "bool"
rrField = &rewriterField{
goType: boolTypeName,
goType: boolTypeTinyName,
path: boolPkgPath,
}
} else {
@@ -1250,7 +1251,7 @@ func goTypeToProto(fields []tmplField, jsonNameType int, isCommonStyle bool) []t
switch field.rewriterField.goType {
case jsonTypeName, decimalTypeName:
field.GoType = "string"
case boolTypeName:
case boolTypeName, boolTypeTinyName:
field.GoType = "bool"
}
}