From 352bcb37b2bd329ed6a1d6e058fa2165c23e7df3 Mon Sep 17 00:00:00 2001 From: zodial Date: Mon, 15 Jul 2024 11:31:17 +0800 Subject: [PATCH] feat: postgresql orm --- console/commands/orm.go | 2 +- console/commands/orm/mysql.go.subtext | 2 +- console/commands/orm/mysql.go.text | 2 +- console/commands/pgorm/pgsql.go | 187 +++++++++++----- console/commands/pgorm/pgsql.go.subtext | 278 ++++++++++++++++++++++++ console/commands/pgorm/pgsql.go.text | 67 +++++- 6 files changed, 471 insertions(+), 67 deletions(-) create mode 100644 console/commands/pgorm/pgsql.go.subtext diff --git a/console/commands/orm.go b/console/commands/orm.go index aefa581..df8961e 100644 --- a/console/commands/orm.go +++ b/console/commands/orm.go @@ -65,7 +65,7 @@ func (OrmCommand) Execute(input command.Input) { switch driver { case "mysql": orm.GenMysql(s.(string), conf, out) - case "pgsql": + case "postgresql": pgorm.GenSql(s.(string), conf, out) } diff --git a/console/commands/orm/mysql.go.subtext b/console/commands/orm/mysql.go.subtext index 371c368..85e85f8 100644 --- a/console/commands/orm/mysql.go.subtext +++ b/console/commands/orm/mysql.go.subtext @@ -123,7 +123,7 @@ func (orm *OrmMysqlTableName) Offset(offset int) *OrmMysqlTableName { orm.db.Offset(offset) return orm } -// 直接查询列表, 如果需要条数, 使用Find() +// Get 直接查询列表, 如果需要条数, 使用Find() func (orm *OrmMysqlTableName) Get() MysqlTableNameList { got, _ := orm.Find() return got diff --git a/console/commands/orm/mysql.go.text b/console/commands/orm/mysql.go.text index 15914b9..7ed2486 100644 --- a/console/commands/orm/mysql.go.text +++ b/console/commands/orm/mysql.go.text @@ -141,7 +141,7 @@ func (orm *OrmMysqlTableName) Offset(offset int) *OrmMysqlTableName { orm.db.Offset(offset) return orm } -// 直接查询列表, 如果需要条数, 使用Find() +// Get 直接查询列表, 如果需要条数, 使用Find() func (orm *OrmMysqlTableName) Get() MysqlTableNameList { got, _ := orm.Find() return got diff --git a/console/commands/pgorm/pgsql.go b/console/commands/pgorm/pgsql.go index 9a059a2..5875959 100644 --- a/console/commands/pgorm/pgsql.go +++ b/console/commands/pgorm/pgsql.go @@ -11,6 +11,7 @@ import ( _ "github.com/lib/pq" "log" "os" + "strconv" "strings" "time" ) @@ -30,7 +31,8 @@ func GenSql(name string, conf Conf, out string) { } db := NewDb(conf) - tableColumns := db.tableColumns() + tableInfos := db.tableColumns() + tableColumns := tableInfos.Columns root, _ := os.Getwd() file, err := os.ReadFile(root + "/config/database/" + name + ".json") @@ -43,20 +45,30 @@ func GenSql(name string, conf Conf, out string) { } // 计算import - imports := getImports(tableColumns) + imports := getImports(tableInfos.Infos, tableColumns) for table, columns := range tableColumns { - tableName := parser.StringToSnake(table) - file := out + "/" + tableName + tableConfig := tableInfos.Infos[table] + mysqlTableName := parser.StringToSnake(table) + file := out + "/" + mysqlTableName + + if _, err := os.Stat(file + "_lock.go"); !os.IsNotExist(err) { + continue + } str := "package " + name str += "\nimport (" + imports[table] + "\n)" str += "\n" + genOrmStruct(table, columns, conf, relationship[table]) - baseFunStr := baseMysqlFuncStr + var baseFunStr string + if tableConfig.IsSub() { + baseFunStr = basePgsqlSubInfoStr + } else { + baseFunStr = basePgsqlFuncStr + } for old, newStr := range map[string]string{ "{orm_table_name}": parser.StringToHump(table), "{table_name}": table, - "{db}": name, + "{connect_name}": name, } { baseFunStr = strings.ReplaceAll(baseFunStr, old, newStr) } @@ -75,22 +87,23 @@ func genListFunc(table string, columns []tableColumn) string { TableName := parser.StringToHump(table) str := "\ntype " + TableName + "List []*" + TableName for _, column := range columns { + ColumnName := parser.StringToHump(column.ColumnName) // 索引,或者枚举字段 - if strInStr(column.ColumnName, []string{"id", "code"}) { + if strInStr(column.ColumnName, []string{"id", "code"}) || strInStr(column.Comment, []string{"@index"}) { goType := column.GoType if column.IsNullable { goType = "*" + goType } - str += "\nfunc (l " + TableName + "List) Get" + column.ColumnName + "List() []" + goType + " {" + + str += "\nfunc (l " + TableName + "List) Get" + ColumnName + "List() []" + goType + " {" + "\n\tgot := make([]" + goType + ", 0)\n\tfor _, val := range l {" + - "\n\t\tgot = append(got, val." + column.ColumnName + ")" + + "\n\t\tgot = append(got, val." + ColumnName + ")" + "\n\t}" + "\n\treturn got" + "\n}" - str += "\nfunc (l " + TableName + "List) Get" + column.ColumnName + "Map() map[" + goType + "]*" + TableName + " {" + + str += "\nfunc (l " + TableName + "List) Get" + ColumnName + "Map() map[" + goType + "]*" + TableName + " {" + "\n\tgot := make(map[" + goType + "]*" + TableName + ")\n\tfor _, val := range l {" + - "\n\t\tgot[val." + column.ColumnName + "] = val" + + "\n\t\tgot[val." + ColumnName + "] = val" + "\n\t}" + "\n\treturn got" + "\n}" @@ -104,73 +117,76 @@ func genFieldFunc(table string, columns []tableColumn) string { str := "" for _, column := range columns { + ColumnName := parser.StringToHump(column.ColumnName) // 等于函数 - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "(val " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` = ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" = ?\", val)" + "\n\treturn orm" + "\n}" - if column.IsPKey { + if strInStr(column.GoType, []string{"int32", "int64"}) { goType := column.GoType if column.IsNullable { goType = "*" + goType } // if 主键, 生成In, > < - str += "\nfunc (orm *Orm" + TableName + ") InsertGet" + column.ColumnName + "(row *" + TableName + ") " + goType + " {" + - "\n\torm.db.Create(row)" + - "\n\treturn row." + column.ColumnName + - "\n}" + if column.IsPKey { + str += "\nfunc (orm *Orm" + TableName + ") InsertGet" + ColumnName + "(row *" + TableName + ") " + goType + " {" + + "\n\torm.db.Create(row)" + + "\n\treturn row." + ColumnName + + "\n}" + } - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` IN ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" IN ?\", val)" + "\n\treturn orm" + "\n}" - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gt(val " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` > ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gt(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" > ?\", val)" + "\n\treturn orm" + "\n}" - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` >= ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" >= ?\", val)" + "\n\treturn orm" + "\n}" - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lt(val " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` < ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lt(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" < ?\", val)" + "\n\treturn orm" + "\n}" - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` <= ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <= ?\", val)" + "\n\treturn orm" + "\n}" } else { // 索引,或者枚举字段 if strInStr(column.ColumnName, []string{"id", "code", "status", "state"}) { // else if 名称存在 id, code, status 生成in操作 - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` IN ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" IN ?\", val)" + "\n\treturn orm" + "\n}" - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Ne(val " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` <> ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Ne(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <> ?\", val)" + "\n\treturn orm" + "\n}" } // 时间字段 if strInStr(column.ColumnName, []string{"created", "updated", "time", "_at"}) || (column.GoType == "database.Time") { - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Between(begin " + column.GoType + ", end " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` BETWEEN ? AND ?\", begin, end)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Between(begin " + column.GoType + ", end " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" BETWEEN ? AND ?\", begin, end)" + "\n\treturn orm" + "\n}" - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` <= ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <= ?\", val)" + "\n\treturn orm" + "\n}" - str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" + - "\n\torm.db.Where(\"`" + column.ColumnName + "` >= ?\", val)" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" >= ?\", val)" + "\n\treturn orm" + "\n}" } @@ -189,17 +205,19 @@ func strInStr(s string, in []string) bool { return false } +//go:embed pgsql.go.subtext +var basePgsqlSubInfoStr string + //go:embed pgsql.go.text -var baseMysqlFuncStr string +var basePgsqlFuncStr string // 字段类型引入 var alias = map[string]string{ - "database": "github.com/go-home-admin/home/bootstrap/services/database", - "datatypes": "gorm.io/datatypes", + "database": "github.com/go-home-admin/home/bootstrap/services/database", } // 获得 table => map{alias => github.com/*} -func getImports(tableColumns map[string][]tableColumn) map[string]string { +func getImports(infos map[string]orm.TableInfos, tableColumns map[string][]tableColumn) map[string]string { got := make(map[string]string) for table, columns := range tableColumns { // 初始引入 @@ -211,10 +229,14 @@ func getImports(tableColumns map[string][]tableColumn) map[string]string { "database/sql": "sql", "github.com/go-home-admin/home/app": "home", } + if infos[table].IsSub() { + delete(tm, "github.com/go-home-admin/home/bootstrap/providers") + } + for _, column := range columns { index := strings.Index(column.GoType, ".") - if index != -1 && column.GoType[:index] != "gorm" { - as := strings.Replace(column.GoType[:index], "*", "", 1) + if index != -1 { + as := column.GoType[:index] importStr := alias[as] tm[importStr] = as } @@ -232,12 +254,20 @@ func genOrmStruct(table string, columns []tableColumn, conf Conf, relationships str := `type {TableName} struct {` for _, column := range columns { p := "" - if column.IsNullable && column.ColumnName != "deleted_at" { + if column.IsNullable && !(column.ColumnName == "deleted_at" && column.GoType == "database.Time") { p = "*" } - if column.ColumnName == "deleted_at" { + if column.ColumnName == "deleted_at" && column.GoType == "database.Time" { column.GoType = "gorm.DeletedAt" } + + // 使用注释@Type(int), 强制设置生成的go struct 属性 类型 + if i := strings.Index(column.ColumnName, "@type("); i != -1 { + s := column.Comment[i+6:] + e := strings.Index(s, ")") + column.GoType = s[:e] + } + hasField[column.ColumnName] = true fieldName := parser.StringToHump(column.ColumnName) str += fmt.Sprintf("\n\t%v %v%v`%v` // %v", fieldName, p, column.GoType, genGormTag(column), strings.ReplaceAll(column.Comment, "\n", " ")) @@ -310,6 +340,8 @@ func genGormTag(column tableColumn) string { // 主键 if column.IsPKey { arr = append(arr, "primaryKey") + } else if column.IndexName != "" { + arr = append(arr, "index:"+column.ColumnName) } // default if column.ColumnDefault != "" { @@ -335,13 +367,15 @@ func (d *DB) GetDB() *sql.DB { return d.db } -func (d *DB) tableColumns() map[string][]tableColumn { +// 获取所有表信息 +// 过滤分表信息, table_{1-9} 只返回table +func (d *DB) tableColumns() TableInfo { var sqlStr = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'" rows, err := d.db.Query(sqlStr) if err != nil { log.Println("Error reading table information: ", err.Error()) - return nil + return TableInfo{} } defer rows.Close() ormColumns := make(map[string][]tableColumn) @@ -352,10 +386,12 @@ func (d *DB) tableColumns() map[string][]tableColumn { &tableName, ) _rows, _ := d.db.Query(` -SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment +SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment, ixc.relname FROM information_schema.columns as i LEFT JOIN pg_class as c on c.relname = i.table_name LEFT JOIN pg_attribute as a on a.attrelid = c.oid and a.attname = i.column_name +LEFT JOIN pg_index ix ON c.oid = ix.indrelid AND a.attnum = ANY(ix.indkey) +LEFT JOIN pg_class ixc ON ixc.oid = ix.indexrelid WHERE table_schema = 'public' and i.table_name = $1; `, tableName) defer _rows.Close() @@ -380,32 +416,74 @@ WHERE pg_class.relname = $1 AND pg_constraint.contype = 'p' is_nullable string udt_name string comment *string + index_name *string ) - err = _rows.Scan(&column_name, &column_default, &is_nullable, &udt_name, &comment) + err = _rows.Scan(&column_name, &column_default, &is_nullable, &udt_name, &comment, &index_name) if err != nil { panic(err) } - var columnComment string + var columnComment, indexName string if comment != nil { columnComment = *comment } + if index_name != nil { + indexName = *index_name + } var ColumnDefault string if column_default != nil { ColumnDefault = *column_default } ormColumns[tableName] = append(ormColumns[tableName], tableColumn{ - ColumnName: parser.StringToHump(column_name), + ColumnName: column_name, ColumnDefault: ColumnDefault, PgType: udt_name, GoType: PgTypeToGoType(udt_name, column_name), IsNullable: is_nullable == "YES", - IsPKey: false, + IsPKey: column_name == pkey, Comment: columnComment, + IndexName: indexName, }) } } - return ormColumns + return Filter(ormColumns) +} + +// Filter 过滤分表格式 +// table_{0-9} 只返回table +func Filter(tableColumns map[string][]tableColumn) TableInfo { + info := TableInfo{ + Columns: make(map[string][]tableColumn), + Infos: make(map[string]orm.TableInfos), + } + tableSort := make(map[string]int) + for tableName, columns := range tableColumns { + arr := strings.Split(tableName, "_") + arrLen := len(arr) + if arrLen > 1 { + str := arr[arrLen-1] + tn, err := strconv.Atoi(str) + if err == nil { + tableName = strings.ReplaceAll(tableName, "_"+str, "") + info.Infos[tableName] = orm.TableInfos{ + "sub": "true", // 分表 + } + // 保留数字最大的 + n, ok := tableSort[tableName] + if ok && n > tn { + continue + } + tableSort[tableName] = tn + } + } + info.Columns[tableName] = columns + } + return info +} + +type TableInfo struct { + Columns map[string][]tableColumn + Infos map[string]orm.TableInfos } type tableColumn struct { @@ -417,6 +495,7 @@ type tableColumn struct { IsNullable bool IsPKey bool Comment string + IndexName string } func PgTypeToGoType(pgType string, columnName string) string { diff --git a/console/commands/pgorm/pgsql.go.subtext b/console/commands/pgorm/pgsql.go.subtext new file mode 100644 index 0000000..57272fc --- /dev/null +++ b/console/commands/pgorm/pgsql.go.subtext @@ -0,0 +1,278 @@ + +type Orm{orm_table_name} struct { + db *gorm.DB +} + +func (orm *Orm{orm_table_name}) GetDB() *gorm.DB { + return orm.db +} + +func (orm *Orm{orm_table_name}) GetTableInfo() interface{} { + return &{orm_table_name}{} +} + +// Create insert the value into database +func (orm *Orm{orm_table_name}) Create(value interface{}) *gorm.DB { + return orm.db.Create(value) +} + +// CreateInBatches insert the value in batches into database +func (orm *Orm{orm_table_name}) CreateInBatches(value interface{}, batchSize int) *gorm.DB { + return orm.db.CreateInBatches(value, batchSize) +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (orm *Orm{orm_table_name}) Save(value interface{}) *gorm.DB { + return orm.db.Save(value) +} + +func (orm *Orm{orm_table_name}) Row() *sql.Row { + return orm.db.Row() +} + +func (orm *Orm{orm_table_name}) Rows() (*sql.Rows, error) { + return orm.db.Rows() +} + +// Scan scan value to a struct +func (orm *Orm{orm_table_name}) Scan(dest interface{}) *gorm.DB { + return orm.db.Scan(dest) +} + +func (orm *Orm{orm_table_name}) ScanRows(rows *sql.Rows, dest interface{}) error { + return orm.db.ScanRows(rows, dest) +} + +// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +func (orm *Orm{orm_table_name}) Connection(fc func(tx *gorm.DB) error) (err error) { + return orm.db.Connection(fc) +} + +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +func (orm *Orm{orm_table_name}) Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error) { + return orm.db.Transaction(fc, opts...) +} + +// Begin begins a transaction +func (orm *Orm{orm_table_name}) Begin(opts ...*sql.TxOptions) *gorm.DB { + return orm.db.Begin(opts...) +} + +// Commit commit a transaction +func (orm *Orm{orm_table_name}) Commit() *gorm.DB { + return orm.db.Commit() +} + +// Rollback rollback a transaction +func (orm *Orm{orm_table_name}) Rollback() *gorm.DB { + return orm.db.Rollback() +} + +func (orm *Orm{orm_table_name}) SavePoint(name string) *gorm.DB { + return orm.db.SavePoint(name) +} + +func (orm *Orm{orm_table_name}) RollbackTo(name string) *gorm.DB { + return orm.db.RollbackTo(name) +} + +// Exec execute raw sql +func (orm *Orm{orm_table_name}) Exec(sql string, values ...interface{}) *gorm.DB { + return orm.db.Exec(sql, values...) +} + +// Exists 检索对象是否存在 +func (orm *Orm{orm_table_name}) Exists() (bool, error) { + dest := &struct { + H int `json:"h"` + }{} + db := orm.db.Select("1 as h").Limit(1).Find(dest) + return dest.H == 1, db.Error +} + +func (orm *Orm{orm_table_name}) Unscoped() *Orm{orm_table_name} { + orm.db.Unscoped() + return orm +} +// ------------ 以下是单表独有的函数, 便捷字段条件, Laravel风格操作 --------- + +func (orm *Orm{orm_table_name}) Insert(row *{orm_table_name}) error { + return orm.db.Create(row).Error +} + +func (orm *Orm{orm_table_name}) Inserts(rows []*{orm_table_name}) *gorm.DB { + return orm.db.Create(rows) +} + +func (orm *Orm{orm_table_name}) Order(value interface{}) *Orm{orm_table_name} { + orm.db.Order(value) + return orm +} + +func (orm *Orm{orm_table_name}) Group(name string) *Orm{orm_table_name} { + orm.db.Group(name) + return orm +} + +func (orm *Orm{orm_table_name}) Limit(limit int) *Orm{orm_table_name} { + orm.db.Limit(limit) + return orm +} + +func (orm *Orm{orm_table_name}) Offset(offset int) *Orm{orm_table_name} { + orm.db.Offset(offset) + return orm +} +// Get 直接查询列表, 如果需要条数, 使用Find() +func (orm *Orm{orm_table_name}) Get() {orm_table_name}List { + got, _ := orm.Find() + return got +} + +// Pluck used to query single column from a model as a map +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) +func (orm *Orm{orm_table_name}) Pluck(column string, dest interface{}) *gorm.DB { + return orm.db.Pluck(column, dest) +} + +// Delete 有条件删除 +func (orm *Orm{orm_table_name}) Delete(conds ...interface{}) *gorm.DB { + return orm.db.Delete(&{orm_table_name}{}, conds...) +} + +// DeleteAll 删除所有 +func (orm *Orm{orm_table_name}) DeleteAll() *gorm.DB { + return orm.db.Exec("DELETE FROM {table_name}") +} + +func (orm *Orm{orm_table_name}) Count() int64 { + var count int64 + orm.db.Count(&count) + return count +} + +// First 检索单个对象 +func (orm *Orm{orm_table_name}) First(conds ...interface{}) (*{orm_table_name}, bool) { + dest := &{orm_table_name}{} + db := orm.db.Limit(1).Find(dest, conds...) + return dest, db.RowsAffected == 1 +} + +// Take return a record that match given conditions, the order will depend on the database implementation +func (orm *Orm{orm_table_name}) Take(conds ...interface{}) (*{orm_table_name}, int64) { + dest := &{orm_table_name}{} + db := orm.db.Take(dest, conds...) + return dest, db.RowsAffected +} + +// Last find last record that match given conditions, order by primary key +func (orm *Orm{orm_table_name}) Last(conds ...interface{}) (*{orm_table_name}, int64) { + dest := &{orm_table_name}{} + db := orm.db.Last(dest, conds...) + return dest, db.RowsAffected +} + +func (orm *Orm{orm_table_name}) Find(conds ...interface{}) ({orm_table_name}List, int64) { + list := make([]*{orm_table_name}, 0) + tx := orm.db.Find(&list, conds...) + if tx.Error != nil { + logrus.Error(tx.Error) + } + return list, tx.RowsAffected +} + +// Paginate 分页 +func (orm *Orm{orm_table_name}) Paginate(page int, limit int) ({orm_table_name}List, int64) { + var total int64 + list := make([]*{orm_table_name}, 0) + orm.db.Count(&total) + if total > 0 { + if page == 0 { + page = 1 + } + + offset := (page - 1) * limit + tx := orm.db.Offset(offset).Limit(limit).Find(&list) + if tx.Error != nil { + logrus.Error(tx.Error) + } + } + + return list, total +} + +// FindInBatches find records in batches +func (orm *Orm{orm_table_name}) FindInBatches(dest interface{}, batchSize int, fc func(tx *gorm.DB, batch int) error) *gorm.DB { + return orm.db.FindInBatches(dest, batchSize, fc) +} + +// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) +func (orm *Orm{orm_table_name}) FirstOrInit(dest *{orm_table_name}, conds ...interface{}) (*{orm_table_name}, *gorm.DB) { + return dest, orm.db.FirstOrInit(dest, conds...) +} + +// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) +func (orm *Orm{orm_table_name}) FirstOrCreate(dest interface{}, conds ...interface{}) *gorm.DB { + return orm.db.FirstOrCreate(dest, conds...) +} + +// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +func (orm *Orm{orm_table_name}) Update(column string, value interface{}) *gorm.DB { + return orm.db.Update(column, value) +} + +// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +func (orm *Orm{orm_table_name}) Updates(values interface{}) *gorm.DB { + return orm.db.Updates(values) +} + +func (orm *Orm{orm_table_name}) UpdateColumn(column string, value interface{}) *gorm.DB { + return orm.db.UpdateColumn(column, value) +} + +func (orm *Orm{orm_table_name}) UpdateColumns(values interface{}) *gorm.DB { + return orm.db.UpdateColumns(values) +} + +func (orm *Orm{orm_table_name}) Where(query interface{}, args ...interface{}) *Orm{orm_table_name} { + orm.db.Where(query, args...) + return orm +} + +func (orm *Orm{orm_table_name}) Select(query interface{}, args ...interface{}) *Orm{orm_table_name} { + orm.db.Select(query, args...) + return orm +} + +func (orm *Orm{orm_table_name}) Sum(field string) int64 { + type result struct { + S int64 `json:"s"` + } + ret := result{} + orm.db.Select("SUM(\""+field+"\") AS s").Scan(&ret) + return ret.S +} + +// Preload preload associations with given conditions +// db.Preload("Orders|orders", "state NOT IN (?)", "cancelled").Find(&users) +func (orm *Orm{orm_table_name}) Preload(query string, args ...interface{}) *Orm{orm_table_name} { + arr := strings.Split(query, ".") + for i, _ := range arr { + arr[i] = home.StringToHump(arr[i]) + } + orm.db.Preload(strings.Join(arr, "."), args...) + return orm +} + +// Joins specify Joins conditions +// db.Joins("Account|account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) +func (orm *Orm{orm_table_name}) Joins(query string, args ...interface{}) *Orm{orm_table_name} { + if !strings.Contains(query, " ") { + query = home.StringToHump(query) + } + orm.db.Joins(query, args...) + return orm +} diff --git a/console/commands/pgorm/pgsql.go.text b/console/commands/pgorm/pgsql.go.text index 394a569..557e9a0 100644 --- a/console/commands/pgorm/pgsql.go.text +++ b/console/commands/pgorm/pgsql.go.text @@ -7,16 +7,28 @@ type Orm{orm_table_name} struct { db *gorm.DB } -func NewOrm{orm_table_name}() *Orm{orm_table_name} { - orm := &Orm{orm_table_name}{} - orm.db = providers.NewMysqlProvider().GetBean("{db}").(*gorm.DB).Model(&MysqlTableName{}).Model(&{orm_table_name}{}) - return orm +func (orm *Orm{orm_table_name}) TableName() string { + return "{table_name}" +} + +func NewOrm{orm_table_name}(txs ...*gorm.DB) *Orm{orm_table_name} { + var tx *gorm.DB + if len(txs) == 0 { + tx = providers.GetBean("postgresql, {connect_name}").(*gorm.DB).Model(&{orm_table_name}{}) + } else { + tx = txs[0].Model(&{orm_table_name}{}) + } + return &Orm{orm_table_name}{db: tx} } func (orm *Orm{orm_table_name}) GetDB() *gorm.DB { return orm.db } +func (orm *Orm{orm_table_name}) GetTableInfo() interface{} { + return &{orm_table_name}{} +} + // Create insert the value into database func (orm *Orm{orm_table_name}) Create(value interface{}) *gorm.DB { return orm.db.Create(value) @@ -87,6 +99,19 @@ func (orm *Orm{orm_table_name}) Exec(sql string, values ...interface{}) *gorm.DB return orm.db.Exec(sql, values...) } +// Exists 检索对象是否存在 +func (orm *Orm{orm_table_name}) Exists() (bool, error) { + dest := &struct { + H int `json:"h"` + }{} + db := orm.db.Select("1 as h").Limit(1).Find(dest) + return dest.H == 1, db.Error +} + +func (orm *Orm{orm_table_name}) Unscoped() *Orm{orm_table_name} { + orm.db.Unscoped() + return orm +} // ------------ 以下是单表独有的函数, 便捷字段条件, Laravel风格操作 --------- func (orm *Orm{orm_table_name}) Insert(row *{orm_table_name}) error { @@ -102,6 +127,11 @@ func (orm *Orm{orm_table_name}) Order(value interface{}) *Orm{orm_table_name} { return orm } +func (orm *Orm{orm_table_name}) Group(name string) *Orm{orm_table_name} { + orm.db.Group(name) + return orm +} + func (orm *Orm{orm_table_name}) Limit(limit int) *Orm{orm_table_name} { orm.db.Limit(limit) return orm @@ -111,7 +141,7 @@ func (orm *Orm{orm_table_name}) Offset(offset int) *Orm{orm_table_name} { orm.db.Offset(offset) return orm } -// 直接查询列表, 如果需要条数, 使用Find() +// Get 直接查询列表, 如果需要条数, 使用Find() func (orm *Orm{orm_table_name}) Get() {orm_table_name}List { got, _ := orm.Find() return got @@ -191,12 +221,11 @@ func (orm *Orm{orm_table_name}) Paginate(page int, limit int) ({orm_table_name}L } // SimplePaginate 不统计总数的分页 -func (orm *Orm{orm_table_name}) Paginate(page int, limit int) {orm_table_name}List { +func (orm *Orm{orm_table_name}) SimplePaginate(page int, limit int) {orm_table_name}List { list := make([]*{orm_table_name}, 0) if page == 0 { page = 1 } - offset := (page - 1) * limit tx := orm.db.Offset(offset).Limit(limit).Find(&list) if tx.Error != nil { @@ -243,6 +272,20 @@ func (orm *Orm{orm_table_name}) Where(query interface{}, args ...interface{}) *O return orm } +func (orm *Orm{orm_table_name}) Select(query interface{}, args ...interface{}) *Orm{orm_table_name} { + orm.db.Select(query, args...) + return orm +} + +func (orm *Orm{orm_table_name}) Sum(field string) int64 { + type result struct { + S int64 `json:"s"` + } + ret := result{} + orm.db.Select("SUM(\""+field+"\") AS s").Scan(&ret) + return ret.S +} + func (orm *Orm{orm_table_name}) And(fuc func(orm *Orm{orm_table_name})) *Orm{orm_table_name} { ormAnd := NewOrm{orm_table_name}() fuc(ormAnd) @@ -259,8 +302,12 @@ func (orm *Orm{orm_table_name}) Or(fuc func(orm *Orm{orm_table_name})) *Orm{orm_ // Preload preload associations with given conditions // db.Preload("Orders|orders", "state NOT IN (?)", "cancelled").Find(&users) -func (orm *OrmMysqlTableName) Preload(query string, args ...interface{}) *OrmMysqlTableName { - orm.db.Preload(home.StringToHump(query), args...) +func (orm *Orm{orm_table_name}) Preload(query string, args ...interface{}) *Orm{orm_table_name} { + arr := strings.Split(query, ".") + for i, _ := range arr { + arr[i] = home.StringToHump(arr[i]) + } + orm.db.Preload(strings.Join(arr, "."), args...) return orm } @@ -268,7 +315,7 @@ func (orm *OrmMysqlTableName) Preload(query string, args ...interface{}) *OrmMys // db.Joins("Account|account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) -func (orm *OrmMysqlTableName) Joins(query string, args ...interface{}) *OrmMysqlTableName { +func (orm *Orm{orm_table_name}) Joins(query string, args ...interface{}) *Orm{orm_table_name} { if !strings.Contains(query, " ") { query = home.StringToHump(query) }