mirror of
https://github.com/zhufuyi/sponge.git
synced 2025-10-04 00:16:25 +08:00
support mongodb generate code
This commit is contained in:
@@ -43,6 +43,8 @@ const (
|
||||
DBDriverTidb = "tidb"
|
||||
// DBDriverSqlite sqlite driver
|
||||
DBDriverSqlite = "sqlite"
|
||||
// DBDriverMongodb mongodb driver
|
||||
DBDriverMongodb = "mongodb"
|
||||
)
|
||||
|
||||
// Codes content
|
||||
@@ -126,12 +128,15 @@ func ParseSQL(sql string, options ...Option) (map[string]string, error) {
|
||||
}
|
||||
|
||||
type tmplData struct {
|
||||
TableName string
|
||||
TName string
|
||||
NameFunc bool
|
||||
RawTableName string
|
||||
Fields []tmplField
|
||||
Comment string
|
||||
TableName string
|
||||
TName string
|
||||
NameFunc bool
|
||||
RawTableName string
|
||||
Fields []tmplField
|
||||
Comment string
|
||||
SubStructs string // sub structs for model
|
||||
ProtoSubStructs string // sub structs for protobuf
|
||||
DBDriver string
|
||||
}
|
||||
|
||||
type tmplField struct {
|
||||
@@ -141,9 +146,10 @@ type tmplField struct {
|
||||
Tag string
|
||||
Comment string
|
||||
JSONName string
|
||||
DBDriver string
|
||||
}
|
||||
|
||||
// ConditionZero type of condition 0
|
||||
// ConditionZero type of condition 0, used in dao template code
|
||||
func (t tmplField) ConditionZero() string {
|
||||
switch t.GoType {
|
||||
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32", //nolint
|
||||
@@ -153,12 +159,28 @@ func (t tmplField) ConditionZero() string {
|
||||
return `!= ""`
|
||||
case "time.Time", "*time.Time", "sql.NullTime": //nolint
|
||||
return `.IsZero() == false`
|
||||
case "[]byte", "[]string", "[]int", "interface{}": //nolint
|
||||
return `!= nil` //nolint
|
||||
case "bool": //nolint
|
||||
return `!= false /*Warning: if the value itself is false, can't be updated*/`
|
||||
}
|
||||
|
||||
if t.DBDriver == DBDriverMongodb {
|
||||
if t.GoType == goTypeOID {
|
||||
return `!= primitive.NilObjectID`
|
||||
}
|
||||
if t.GoType == "*"+t.Name {
|
||||
return `!= nil`
|
||||
}
|
||||
if strings.Contains(t.GoType, "[]") {
|
||||
return `!= nil`
|
||||
}
|
||||
}
|
||||
|
||||
return `!= ` + t.GoType
|
||||
}
|
||||
|
||||
// GoZero type of 0
|
||||
// GoZero type of 0, used in model to json template code
|
||||
func (t tmplField) GoZero() string {
|
||||
switch t.GoType {
|
||||
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32",
|
||||
@@ -168,12 +190,28 @@ func (t tmplField) GoZero() string {
|
||||
return `= "string"`
|
||||
case "time.Time", "*time.Time", "sql.NullTime":
|
||||
return `= "0000-01-00T00:00:00.000+08:00"`
|
||||
case "[]byte", "[]string", "[]int", "interface{}": //nolint
|
||||
return `= nil` //nolint
|
||||
case "bool": //nolint
|
||||
return `= false`
|
||||
}
|
||||
|
||||
if t.DBDriver == DBDriverMongodb {
|
||||
if t.GoType == goTypeOID {
|
||||
return `= primitive.NilObjectID`
|
||||
}
|
||||
if t.GoType == "*"+t.Name {
|
||||
return `= nil`
|
||||
}
|
||||
if strings.Contains(t.GoType, "[]") {
|
||||
return `= nil`
|
||||
}
|
||||
}
|
||||
|
||||
return `= ` + t.GoType
|
||||
}
|
||||
|
||||
// GoTypeZero type of 0
|
||||
// GoTypeZero type of 0, used in service template code
|
||||
func (t tmplField) GoTypeZero() string {
|
||||
switch t.GoType {
|
||||
case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32",
|
||||
@@ -183,6 +221,22 @@ func (t tmplField) GoTypeZero() string {
|
||||
return `""`
|
||||
case "time.Time", "*time.Time", "sql.NullTime":
|
||||
return `0 /*time.Now().Second()*/`
|
||||
case "[]byte", "[]string", "[]int", "interface{}": //nolint
|
||||
return `nil` //nolint
|
||||
case "bool": //nolint
|
||||
return `false`
|
||||
}
|
||||
|
||||
if t.DBDriver == DBDriverMongodb {
|
||||
if t.GoType == goTypeOID {
|
||||
return `primitive.NilObjectID`
|
||||
}
|
||||
if t.GoType == "*"+t.Name {
|
||||
return `nil` //nolint
|
||||
}
|
||||
if strings.Contains(t.GoType, "[]") {
|
||||
return `nil` //nolint
|
||||
}
|
||||
}
|
||||
|
||||
return t.GoType
|
||||
@@ -214,6 +268,7 @@ var replaceFields = map[string]string{
|
||||
|
||||
const (
|
||||
columnID = "id"
|
||||
_columnID = "_id"
|
||||
columnCreatedAt = "created_at"
|
||||
columnUpdatedAt = "updated_at"
|
||||
columnDeletedAt = "deleted_at"
|
||||
@@ -349,30 +404,56 @@ func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
|
||||
//return "", nil, errors.Errorf(" unsupport option %d\n", o.Tp)
|
||||
}
|
||||
}
|
||||
if !isPrimaryKey[colName] && isNotNull {
|
||||
gormTag.WriteString(";NOT NULL")
|
||||
}
|
||||
tags = append(tags, "gorm", gormTag.String())
|
||||
|
||||
if opt.JSONTag {
|
||||
tags = append(tags, "json", jsonName)
|
||||
}
|
||||
field.DBDriver = opt.DBDriver
|
||||
switch opt.DBDriver {
|
||||
case DBDriverMongodb: // mongodb
|
||||
tags = append(tags, "bson", gormTag.String())
|
||||
if opt.JSONTag {
|
||||
if strings.ToLower(jsonName) == "_id" {
|
||||
jsonName = "id"
|
||||
}
|
||||
field.JSONName = jsonName
|
||||
tags = append(tags, "json", jsonName)
|
||||
}
|
||||
field.Tag = makeTagStr(tags)
|
||||
field.GoType = opt.FieldTypes[colName]
|
||||
if field.GoType == "time.Time" {
|
||||
importPath = append(importPath, "time")
|
||||
}
|
||||
|
||||
field.Tag = makeTagStr(tags)
|
||||
default: // gorm
|
||||
if !isPrimaryKey[colName] && isNotNull {
|
||||
gormTag.WriteString(";NOT NULL")
|
||||
}
|
||||
tags = append(tags, "gorm", gormTag.String())
|
||||
|
||||
// get type in golang
|
||||
nullStyle := opt.NullStyle
|
||||
if !canNull {
|
||||
nullStyle = NullDisable
|
||||
if opt.JSONTag {
|
||||
tags = append(tags, "json", jsonName)
|
||||
}
|
||||
field.Tag = makeTagStr(tags)
|
||||
|
||||
// get type in golang
|
||||
nullStyle := opt.NullStyle
|
||||
if !canNull {
|
||||
nullStyle = NullDisable
|
||||
}
|
||||
goType, pkg := mysqlToGoType(col.Tp, nullStyle)
|
||||
if pkg != "" {
|
||||
importPath = append(importPath, pkg)
|
||||
}
|
||||
field.GoType = goType
|
||||
}
|
||||
goType, pkg := mysqlToGoType(col.Tp, nullStyle)
|
||||
if pkg != "" {
|
||||
importPath = append(importPath, pkg)
|
||||
}
|
||||
field.GoType = goType
|
||||
|
||||
data.Fields = append(data.Fields, field)
|
||||
}
|
||||
if v, ok := opt.FieldTypes[SubStructKey]; ok {
|
||||
data.SubStructs = v
|
||||
}
|
||||
if v, ok := opt.FieldTypes[ProtoSubStructKey]; ok {
|
||||
data.ProtoSubStructs = v
|
||||
}
|
||||
data.DBDriver = opt.DBDriver
|
||||
|
||||
updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
|
||||
if err != nil {
|
||||
@@ -454,13 +535,22 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool) (stri
|
||||
newImportPaths = append(newImportPaths, "github.com/zhufuyi/sponge/pkg/ggorm")
|
||||
} else {
|
||||
for i, field := range data.Fields {
|
||||
if strings.Contains(field.GoType, "time.Time") {
|
||||
data.Fields[i].GoType = "*time.Time"
|
||||
continue
|
||||
}
|
||||
// force conversion of ID field to uint64 type
|
||||
if field.Name == "ID" {
|
||||
data.Fields[i].GoType = "uint64"
|
||||
switch field.DBDriver {
|
||||
case DBDriverMongodb:
|
||||
if field.Name == "ID" {
|
||||
data.Fields[i].GoType = goTypeOID
|
||||
importPaths = append(importPaths, "go.mongodb.org/mongo-driver/bson/primitive")
|
||||
}
|
||||
|
||||
default:
|
||||
if strings.Contains(field.GoType, "time.Time") {
|
||||
data.Fields[i].GoType = "*time.Time"
|
||||
continue
|
||||
}
|
||||
// force conversion of ID field to uint64 type
|
||||
if field.Name == "ID" {
|
||||
data.Fields[i].GoType = "uint64"
|
||||
}
|
||||
}
|
||||
}
|
||||
newImportPaths = importPaths
|
||||
@@ -482,6 +572,16 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool) (stri
|
||||
structCode = strings.ReplaceAll(structCode, __type__, replaceFields[__type__])
|
||||
}
|
||||
|
||||
if data.SubStructs != "" {
|
||||
structCode += data.SubStructs
|
||||
}
|
||||
if data.DBDriver == DBDriverMongodb {
|
||||
structCode = strings.ReplaceAll(structCode, `bson:"column:`, `bson:"`)
|
||||
structCode = strings.ReplaceAll(structCode, `;type:"`, `"`)
|
||||
structCode = strings.ReplaceAll(structCode, `;type:;primary_key`, ``)
|
||||
structCode = strings.ReplaceAll(structCode, `bson:"id" json:"id"`, `bson:"_id" json:"id"`)
|
||||
}
|
||||
|
||||
return structCode, newImportPaths, nil
|
||||
}
|
||||
|
||||
@@ -507,28 +607,39 @@ func getUpdateFieldsCode(data tmplData, isEmbed bool) (string, error) {
|
||||
var newFields = []tmplField{}
|
||||
for _, field := range data.Fields {
|
||||
falseColumns := []string{}
|
||||
if isIgnoreFields(field.ColName, falseColumns...) {
|
||||
if isIgnoreFields(field.ColName, falseColumns...) || field.ColName == columnID || field.ColName == _columnID {
|
||||
continue
|
||||
}
|
||||
newFields = append(newFields, field)
|
||||
}
|
||||
data.Fields = newFields
|
||||
|
||||
builder := strings.Builder{}
|
||||
err := updateFieldTmpl.Execute(&builder, data)
|
||||
buf := new(bytes.Buffer)
|
||||
err := updateFieldTmpl.Execute(buf, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
code, err := format.Source([]byte(builder.String()))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(code), nil
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func getHandlerStructCodes(data tmplData) (string, error) {
|
||||
newFields := []tmplField{}
|
||||
for _, field := range data.Fields {
|
||||
if field.DBDriver == DBDriverMongodb { // mongodb
|
||||
if field.Name == "ID" {
|
||||
field.GoType = "string"
|
||||
}
|
||||
if "*"+field.Name == field.GoType {
|
||||
field.GoType = "*model." + field.Name
|
||||
}
|
||||
if strings.Contains(field.GoType, "[]*") {
|
||||
field.GoType = "[]*model." + strings.ReplaceAll(field.GoType, "[]*", "")
|
||||
}
|
||||
}
|
||||
newFields = append(newFields, field)
|
||||
}
|
||||
data.Fields = newFields
|
||||
|
||||
postStructCode, err := tmplExecuteWithFilter(data, handlerCreateStructTmpl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("handlerCreateStructTmpl error: %v", err)
|
||||
@@ -554,6 +665,11 @@ func tmplExecuteWithFilter(data tmplData, tmpl *template.Template, reservedColum
|
||||
if isIgnoreFields(field.ColName, reservedColumns...) {
|
||||
continue
|
||||
}
|
||||
if field.DBDriver == DBDriverMongodb { // mongodb
|
||||
if strings.ToLower(field.Name) == "id" {
|
||||
field.GoType = "string"
|
||||
}
|
||||
}
|
||||
newFields = append(newFields, field)
|
||||
}
|
||||
data.Fields = newFields
|
||||
@@ -624,10 +740,86 @@ func getProtoFileCode(data tmplData, isWebProto bool) (string, error) {
|
||||
code = strings.ReplaceAll(code, "// protoMessageDetailCode", protoMessageDetailCode)
|
||||
code = strings.ReplaceAll(code, "*time.Time", "int64")
|
||||
code = strings.ReplaceAll(code, "time.Time", "int64")
|
||||
code = adaptedDbType(data, isWebProto, code)
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
const (
|
||||
createTableReplyFieldCodeMark = "// createTableReplyFieldCode"
|
||||
deleteTableByIDRequestFieldCodeMark = "// deleteTableByIDRequestFieldCode"
|
||||
deleteTableByIDsRequestFieldCodeMark = "// deleteTableByIDsRequestFieldCode"
|
||||
getTableByIDRequestFieldCodeMark = "// getTableByIDRequestFieldCode"
|
||||
getTableByIDsRequestFieldCodeMark = "// getTableByIDsRequestFieldCode"
|
||||
listTableByLastIDRequestFieldCodeMark = "// listTableByLastIDRequestFieldCode"
|
||||
)
|
||||
|
||||
var grpcDefaultProtoMessageFieldCodes = map[string]string{
|
||||
createTableReplyFieldCodeMark: "uint64 id = 1;",
|
||||
deleteTableByIDRequestFieldCodeMark: "uint64 id = 1 [(validate.rules).uint64.gt = 0];",
|
||||
deleteTableByIDsRequestFieldCodeMark: "repeated uint64 ids = 1 [(validate.rules).repeated.min_items = 1];",
|
||||
getTableByIDRequestFieldCodeMark: "uint64 id = 1 [(validate.rules).uint64.gt = 0];",
|
||||
getTableByIDsRequestFieldCodeMark: "repeated uint64 ids = 1 [(validate.rules).repeated.min_items = 1];",
|
||||
listTableByLastIDRequestFieldCodeMark: "uint64 lastID = 1; // last id",
|
||||
}
|
||||
|
||||
var webDefaultProtoMessageFieldCodes = map[string]string{
|
||||
createTableReplyFieldCodeMark: "uint64 id = 1;",
|
||||
deleteTableByIDRequestFieldCodeMark: `uint64 id =1 [(validate.rules).uint64.gt = 0, (tagger.tags) = "uri:\"id\""];`,
|
||||
deleteTableByIDsRequestFieldCodeMark: "repeated uint64 ids = 1 [(validate.rules).repeated.min_items = 1];",
|
||||
getTableByIDRequestFieldCodeMark: `uint64 id =1 [(validate.rules).uint64.gt = 0, (tagger.tags) = "uri:\"id\"" ];`,
|
||||
getTableByIDsRequestFieldCodeMark: "repeated uint64 ids = 1 [(validate.rules).repeated.min_items = 1];",
|
||||
listTableByLastIDRequestFieldCodeMark: `uint64 lastID = 1 [(tagger.tags) = "form:\"lastID\""]; // last id`,
|
||||
}
|
||||
|
||||
var grpcProtoMessageFieldCodes = map[string]string{
|
||||
createTableReplyFieldCodeMark: "string id = 1;",
|
||||
deleteTableByIDRequestFieldCodeMark: "string id = 1 [(validate.rules).string.min_len = 6];",
|
||||
deleteTableByIDsRequestFieldCodeMark: "repeated string ids = 1 [(validate.rules).repeated.min_items = 1];",
|
||||
getTableByIDRequestFieldCodeMark: "string id = 1 [(validate.rules).string.min_len = 6];",
|
||||
getTableByIDsRequestFieldCodeMark: "repeated string ids = 1 [(validate.rules).repeated.min_items = 1];",
|
||||
listTableByLastIDRequestFieldCodeMark: "string lastID = 1; // last id",
|
||||
}
|
||||
|
||||
var webProtoMessageFieldCodes = map[string]string{
|
||||
createTableReplyFieldCodeMark: "string id = 1;",
|
||||
deleteTableByIDRequestFieldCodeMark: `string id =1 [(validate.rules).string.min_len = 6, (tagger.tags) = "uri:\"id\""];`,
|
||||
deleteTableByIDsRequestFieldCodeMark: "repeated string ids = 1 [(validate.rules).repeated.min_items = 1];",
|
||||
getTableByIDRequestFieldCodeMark: `string id =1 [(validate.rules).string.min_len = 6, (tagger.tags) = "uri:\"id\"" ];`,
|
||||
getTableByIDsRequestFieldCodeMark: "repeated string ids = 1 [(validate.rules).repeated.min_items = 1];",
|
||||
listTableByLastIDRequestFieldCodeMark: `string lastID = 1 [(tagger.tags) = "form:\"lastID\""]; // last id`,
|
||||
}
|
||||
|
||||
func adaptedDbType(data tmplData, isWebProto bool, code string) string {
|
||||
switch data.DBDriver {
|
||||
case DBDriverMongodb: // mongodb
|
||||
if isWebProto {
|
||||
code = replaceProtoMessageFieldCode(code, webProtoMessageFieldCodes)
|
||||
} else {
|
||||
code = replaceProtoMessageFieldCode(code, grpcProtoMessageFieldCodes)
|
||||
}
|
||||
default:
|
||||
if isWebProto {
|
||||
code = replaceProtoMessageFieldCode(code, webDefaultProtoMessageFieldCodes)
|
||||
} else {
|
||||
code = replaceProtoMessageFieldCode(code, grpcDefaultProtoMessageFieldCodes)
|
||||
}
|
||||
}
|
||||
|
||||
if data.ProtoSubStructs != "" {
|
||||
code += "\n" + data.ProtoSubStructs
|
||||
}
|
||||
|
||||
return code
|
||||
}
|
||||
|
||||
func replaceProtoMessageFieldCode(code string, messageFields map[string]string) string {
|
||||
for k, v := range messageFields {
|
||||
code = strings.ReplaceAll(code, k, v)
|
||||
}
|
||||
return code
|
||||
}
|
||||
|
||||
func getServiceStructCode(data tmplData) (string, error) {
|
||||
builder := strings.Builder{}
|
||||
err := serviceStructTmpl.Execute(&builder, data)
|
||||
@@ -688,6 +880,7 @@ func addCommaToJSON(modelJSONCode string) string {
|
||||
return out
|
||||
}
|
||||
|
||||
// nolint
|
||||
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string) {
|
||||
if style == NullInSql {
|
||||
path = "database/sql"
|
||||
@@ -746,6 +939,7 @@ func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path s
|
||||
return name, path
|
||||
}
|
||||
|
||||
// nolint
|
||||
func goTypeToProto(fields []tmplField) []tmplField {
|
||||
var newFields []tmplField
|
||||
for _, field := range fields {
|
||||
@@ -760,7 +954,24 @@ func goTypeToProto(fields []tmplField) []tmplField {
|
||||
field.GoType = "float"
|
||||
case "float64":
|
||||
field.GoType = "double"
|
||||
case goTypeInts, "[]int64":
|
||||
field.GoType = "repeated int64"
|
||||
case "[]int32":
|
||||
field.GoType = "repeated int32"
|
||||
case "[]byte":
|
||||
field.GoType = "string"
|
||||
case goTypeStrings:
|
||||
field.GoType = "repeated string"
|
||||
}
|
||||
|
||||
if field.DBDriver == DBDriverMongodb {
|
||||
if field.GoType[0] == '*' {
|
||||
field.GoType = field.GoType[1:]
|
||||
} else if strings.Contains(field.GoType, "[]*") {
|
||||
field.GoType = "repeated " + strings.ReplaceAll(field.GoType, "[]*", "")
|
||||
}
|
||||
}
|
||||
|
||||
newFields = append(newFields, field)
|
||||
}
|
||||
return newFields
|
||||
|
Reference in New Issue
Block a user