feat: postgresql orm

This commit is contained in:
zodial
2024-07-15 11:31:17 +08:00
parent 694baf713b
commit 352bcb37b2
6 changed files with 471 additions and 67 deletions

View File

@@ -65,7 +65,7 @@ func (OrmCommand) Execute(input command.Input) {
switch driver {
case "mysql":
orm.GenMysql(s.(string), conf, out)
case "pgsql":
case "postgresql":
pgorm.GenSql(s.(string), conf, out)
}

View File

@@ -123,7 +123,7 @@ func (orm *OrmMysqlTableName) Offset(offset int) *OrmMysqlTableName {
orm.db.Offset(offset)
return orm
}
// 直接查询列表, 如果需要条数, 使用Find()
// Get 直接查询列表, 如果需要条数, 使用Find()
func (orm *OrmMysqlTableName) Get() MysqlTableNameList {
got, _ := orm.Find()
return got

View File

@@ -141,7 +141,7 @@ func (orm *OrmMysqlTableName) Offset(offset int) *OrmMysqlTableName {
orm.db.Offset(offset)
return orm
}
// 直接查询列表, 如果需要条数, 使用Find()
// Get 直接查询列表, 如果需要条数, 使用Find()
func (orm *OrmMysqlTableName) Get() MysqlTableNameList {
got, _ := orm.Find()
return got

View File

@@ -11,6 +11,7 @@ import (
_ "github.com/lib/pq"
"log"
"os"
"strconv"
"strings"
"time"
)
@@ -30,7 +31,8 @@ func GenSql(name string, conf Conf, out string) {
}
db := NewDb(conf)
tableColumns := db.tableColumns()
tableInfos := db.tableColumns()
tableColumns := tableInfos.Columns
root, _ := os.Getwd()
file, err := os.ReadFile(root + "/config/database/" + name + ".json")
@@ -43,20 +45,30 @@ func GenSql(name string, conf Conf, out string) {
}
// 计算import
imports := getImports(tableColumns)
imports := getImports(tableInfos.Infos, tableColumns)
for table, columns := range tableColumns {
tableName := parser.StringToSnake(table)
file := out + "/" + tableName
tableConfig := tableInfos.Infos[table]
mysqlTableName := parser.StringToSnake(table)
file := out + "/" + mysqlTableName
if _, err := os.Stat(file + "_lock.go"); !os.IsNotExist(err) {
continue
}
str := "package " + name
str += "\nimport (" + imports[table] + "\n)"
str += "\n" + genOrmStruct(table, columns, conf, relationship[table])
baseFunStr := baseMysqlFuncStr
var baseFunStr string
if tableConfig.IsSub() {
baseFunStr = basePgsqlSubInfoStr
} else {
baseFunStr = basePgsqlFuncStr
}
for old, newStr := range map[string]string{
"{orm_table_name}": parser.StringToHump(table),
"{table_name}": table,
"{db}": name,
"{connect_name}": name,
} {
baseFunStr = strings.ReplaceAll(baseFunStr, old, newStr)
}
@@ -75,22 +87,23 @@ func genListFunc(table string, columns []tableColumn) string {
TableName := parser.StringToHump(table)
str := "\ntype " + TableName + "List []*" + TableName
for _, column := range columns {
ColumnName := parser.StringToHump(column.ColumnName)
// 索引,或者枚举字段
if strInStr(column.ColumnName, []string{"id", "code"}) {
if strInStr(column.ColumnName, []string{"id", "code"}) || strInStr(column.Comment, []string{"@index"}) {
goType := column.GoType
if column.IsNullable {
goType = "*" + goType
}
str += "\nfunc (l " + TableName + "List) Get" + column.ColumnName + "List() []" + goType + " {" +
str += "\nfunc (l " + TableName + "List) Get" + ColumnName + "List() []" + goType + " {" +
"\n\tgot := make([]" + goType + ", 0)\n\tfor _, val := range l {" +
"\n\t\tgot = append(got, val." + column.ColumnName + ")" +
"\n\t\tgot = append(got, val." + ColumnName + ")" +
"\n\t}" +
"\n\treturn got" +
"\n}"
str += "\nfunc (l " + TableName + "List) Get" + column.ColumnName + "Map() map[" + goType + "]*" + TableName + " {" +
str += "\nfunc (l " + TableName + "List) Get" + ColumnName + "Map() map[" + goType + "]*" + TableName + " {" +
"\n\tgot := make(map[" + goType + "]*" + TableName + ")\n\tfor _, val := range l {" +
"\n\t\tgot[val." + column.ColumnName + "] = val" +
"\n\t\tgot[val." + ColumnName + "] = val" +
"\n\t}" +
"\n\treturn got" +
"\n}"
@@ -104,73 +117,76 @@ func genFieldFunc(table string, columns []tableColumn) string {
str := ""
for _, column := range columns {
ColumnName := parser.StringToHump(column.ColumnName)
// 等于函数
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` = ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" = ?\", val)" +
"\n\treturn orm" +
"\n}"
if column.IsPKey {
if strInStr(column.GoType, []string{"int32", "int64"}) {
goType := column.GoType
if column.IsNullable {
goType = "*" + goType
}
// if 主键, 生成In, > <
str += "\nfunc (orm *Orm" + TableName + ") InsertGet" + column.ColumnName + "(row *" + TableName + ") " + goType + " {" +
"\n\torm.db.Create(row)" +
"\n\treturn row." + column.ColumnName +
"\n}"
if column.IsPKey {
str += "\nfunc (orm *Orm" + TableName + ") InsertGet" + ColumnName + "(row *" + TableName + ") " + goType + " {" +
"\n\torm.db.Create(row)" +
"\n\treturn row." + ColumnName +
"\n}"
}
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` IN ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" IN ?\", val)" +
"\n\treturn orm" +
"\n}"
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gt(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` > ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gt(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" > ?\", val)" +
"\n\treturn orm" +
"\n}"
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` >= ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" >= ?\", val)" +
"\n\treturn orm" +
"\n}"
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lt(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` < ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lt(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" < ?\", val)" +
"\n\treturn orm" +
"\n}"
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` <= ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <= ?\", val)" +
"\n\treturn orm" +
"\n}"
} else {
// 索引,或者枚举字段
if strInStr(column.ColumnName, []string{"id", "code", "status", "state"}) {
// else if 名称存在 id, code, status 生成in操作
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` IN ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" IN ?\", val)" +
"\n\treturn orm" +
"\n}"
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Ne(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` <> ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Ne(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <> ?\", val)" +
"\n\treturn orm" +
"\n}"
}
// 时间字段
if strInStr(column.ColumnName, []string{"created", "updated", "time", "_at"}) || (column.GoType == "database.Time") {
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Between(begin " + column.GoType + ", end " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` BETWEEN ? AND ?\", begin, end)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Between(begin " + column.GoType + ", end " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" BETWEEN ? AND ?\", begin, end)" +
"\n\treturn orm" +
"\n}"
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` <= ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <= ?\", val)" +
"\n\treturn orm" +
"\n}"
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"`" + column.ColumnName + "` >= ?\", val)" +
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" +
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" >= ?\", val)" +
"\n\treturn orm" +
"\n}"
}
@@ -189,17 +205,19 @@ func strInStr(s string, in []string) bool {
return false
}
//go:embed pgsql.go.subtext
var basePgsqlSubInfoStr string
//go:embed pgsql.go.text
var baseMysqlFuncStr string
var basePgsqlFuncStr string
// 字段类型引入
var alias = map[string]string{
"database": "github.com/go-home-admin/home/bootstrap/services/database",
"datatypes": "gorm.io/datatypes",
"database": "github.com/go-home-admin/home/bootstrap/services/database",
}
// 获得 table => map{alias => github.com/*}
func getImports(tableColumns map[string][]tableColumn) map[string]string {
func getImports(infos map[string]orm.TableInfos, tableColumns map[string][]tableColumn) map[string]string {
got := make(map[string]string)
for table, columns := range tableColumns {
// 初始引入
@@ -211,10 +229,14 @@ func getImports(tableColumns map[string][]tableColumn) map[string]string {
"database/sql": "sql",
"github.com/go-home-admin/home/app": "home",
}
if infos[table].IsSub() {
delete(tm, "github.com/go-home-admin/home/bootstrap/providers")
}
for _, column := range columns {
index := strings.Index(column.GoType, ".")
if index != -1 && column.GoType[:index] != "gorm" {
as := strings.Replace(column.GoType[:index], "*", "", 1)
if index != -1 {
as := column.GoType[:index]
importStr := alias[as]
tm[importStr] = as
}
@@ -232,12 +254,20 @@ func genOrmStruct(table string, columns []tableColumn, conf Conf, relationships
str := `type {TableName} struct {`
for _, column := range columns {
p := ""
if column.IsNullable && column.ColumnName != "deleted_at" {
if column.IsNullable && !(column.ColumnName == "deleted_at" && column.GoType == "database.Time") {
p = "*"
}
if column.ColumnName == "deleted_at" {
if column.ColumnName == "deleted_at" && column.GoType == "database.Time" {
column.GoType = "gorm.DeletedAt"
}
// 使用注释@Type(int), 强制设置生成的go struct 属性 类型
if i := strings.Index(column.ColumnName, "@type("); i != -1 {
s := column.Comment[i+6:]
e := strings.Index(s, ")")
column.GoType = s[:e]
}
hasField[column.ColumnName] = true
fieldName := parser.StringToHump(column.ColumnName)
str += fmt.Sprintf("\n\t%v %v%v`%v` // %v", fieldName, p, column.GoType, genGormTag(column), strings.ReplaceAll(column.Comment, "\n", " "))
@@ -310,6 +340,8 @@ func genGormTag(column tableColumn) string {
// 主键
if column.IsPKey {
arr = append(arr, "primaryKey")
} else if column.IndexName != "" {
arr = append(arr, "index:"+column.ColumnName)
}
// default
if column.ColumnDefault != "" {
@@ -335,13 +367,15 @@ func (d *DB) GetDB() *sql.DB {
return d.db
}
func (d *DB) tableColumns() map[string][]tableColumn {
// 获取所有表信息
// 过滤分表信息, table_{1-9} 只返回table
func (d *DB) tableColumns() TableInfo {
var sqlStr = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
rows, err := d.db.Query(sqlStr)
if err != nil {
log.Println("Error reading table information: ", err.Error())
return nil
return TableInfo{}
}
defer rows.Close()
ormColumns := make(map[string][]tableColumn)
@@ -352,10 +386,12 @@ func (d *DB) tableColumns() map[string][]tableColumn {
&tableName,
)
_rows, _ := d.db.Query(`
SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment
SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment, ixc.relname
FROM information_schema.columns as i
LEFT JOIN pg_class as c on c.relname = i.table_name
LEFT JOIN pg_attribute as a on a.attrelid = c.oid and a.attname = i.column_name
LEFT JOIN pg_index ix ON c.oid = ix.indrelid AND a.attnum = ANY(ix.indkey)
LEFT JOIN pg_class ixc ON ixc.oid = ix.indexrelid
WHERE table_schema = 'public' and i.table_name = $1;
`, tableName)
defer _rows.Close()
@@ -380,32 +416,74 @@ WHERE pg_class.relname = $1 AND pg_constraint.contype = 'p'
is_nullable string
udt_name string
comment *string
index_name *string
)
err = _rows.Scan(&column_name, &column_default, &is_nullable, &udt_name, &comment)
err = _rows.Scan(&column_name, &column_default, &is_nullable, &udt_name, &comment, &index_name)
if err != nil {
panic(err)
}
var columnComment string
var columnComment, indexName string
if comment != nil {
columnComment = *comment
}
if index_name != nil {
indexName = *index_name
}
var ColumnDefault string
if column_default != nil {
ColumnDefault = *column_default
}
ormColumns[tableName] = append(ormColumns[tableName], tableColumn{
ColumnName: parser.StringToHump(column_name),
ColumnName: column_name,
ColumnDefault: ColumnDefault,
PgType: udt_name,
GoType: PgTypeToGoType(udt_name, column_name),
IsNullable: is_nullable == "YES",
IsPKey: false,
IsPKey: column_name == pkey,
Comment: columnComment,
IndexName: indexName,
})
}
}
return ormColumns
return Filter(ormColumns)
}
// Filter 过滤分表格式
// table_{0-9} 只返回table
func Filter(tableColumns map[string][]tableColumn) TableInfo {
info := TableInfo{
Columns: make(map[string][]tableColumn),
Infos: make(map[string]orm.TableInfos),
}
tableSort := make(map[string]int)
for tableName, columns := range tableColumns {
arr := strings.Split(tableName, "_")
arrLen := len(arr)
if arrLen > 1 {
str := arr[arrLen-1]
tn, err := strconv.Atoi(str)
if err == nil {
tableName = strings.ReplaceAll(tableName, "_"+str, "")
info.Infos[tableName] = orm.TableInfos{
"sub": "true", // 分表
}
// 保留数字最大的
n, ok := tableSort[tableName]
if ok && n > tn {
continue
}
tableSort[tableName] = tn
}
}
info.Columns[tableName] = columns
}
return info
}
type TableInfo struct {
Columns map[string][]tableColumn
Infos map[string]orm.TableInfos
}
type tableColumn struct {
@@ -417,6 +495,7 @@ type tableColumn struct {
IsNullable bool
IsPKey bool
Comment string
IndexName string
}
func PgTypeToGoType(pgType string, columnName string) string {

View File

@@ -0,0 +1,278 @@
type Orm{orm_table_name} struct {
db *gorm.DB
}
func (orm *Orm{orm_table_name}) GetDB() *gorm.DB {
return orm.db
}
func (orm *Orm{orm_table_name}) GetTableInfo() interface{} {
return &{orm_table_name}{}
}
// Create insert the value into database
func (orm *Orm{orm_table_name}) Create(value interface{}) *gorm.DB {
return orm.db.Create(value)
}
// CreateInBatches insert the value in batches into database
func (orm *Orm{orm_table_name}) CreateInBatches(value interface{}, batchSize int) *gorm.DB {
return orm.db.CreateInBatches(value, batchSize)
}
// Save update value in database, if the value doesn't have primary key, will insert it
func (orm *Orm{orm_table_name}) Save(value interface{}) *gorm.DB {
return orm.db.Save(value)
}
func (orm *Orm{orm_table_name}) Row() *sql.Row {
return orm.db.Row()
}
func (orm *Orm{orm_table_name}) Rows() (*sql.Rows, error) {
return orm.db.Rows()
}
// Scan scan value to a struct
func (orm *Orm{orm_table_name}) Scan(dest interface{}) *gorm.DB {
return orm.db.Scan(dest)
}
func (orm *Orm{orm_table_name}) ScanRows(rows *sql.Rows, dest interface{}) error {
return orm.db.ScanRows(rows, dest)
}
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
func (orm *Orm{orm_table_name}) Connection(fc func(tx *gorm.DB) error) (err error) {
return orm.db.Connection(fc)
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
func (orm *Orm{orm_table_name}) Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error) {
return orm.db.Transaction(fc, opts...)
}
// Begin begins a transaction
func (orm *Orm{orm_table_name}) Begin(opts ...*sql.TxOptions) *gorm.DB {
return orm.db.Begin(opts...)
}
// Commit commit a transaction
func (orm *Orm{orm_table_name}) Commit() *gorm.DB {
return orm.db.Commit()
}
// Rollback rollback a transaction
func (orm *Orm{orm_table_name}) Rollback() *gorm.DB {
return orm.db.Rollback()
}
func (orm *Orm{orm_table_name}) SavePoint(name string) *gorm.DB {
return orm.db.SavePoint(name)
}
func (orm *Orm{orm_table_name}) RollbackTo(name string) *gorm.DB {
return orm.db.RollbackTo(name)
}
// Exec execute raw sql
func (orm *Orm{orm_table_name}) Exec(sql string, values ...interface{}) *gorm.DB {
return orm.db.Exec(sql, values...)
}
// Exists 检索对象是否存在
func (orm *Orm{orm_table_name}) Exists() (bool, error) {
dest := &struct {
H int `json:"h"`
}{}
db := orm.db.Select("1 as h").Limit(1).Find(dest)
return dest.H == 1, db.Error
}
func (orm *Orm{orm_table_name}) Unscoped() *Orm{orm_table_name} {
orm.db.Unscoped()
return orm
}
// ------------ 以下是单表独有的函数, 便捷字段条件, Laravel风格操作 ---------
func (orm *Orm{orm_table_name}) Insert(row *{orm_table_name}) error {
return orm.db.Create(row).Error
}
func (orm *Orm{orm_table_name}) Inserts(rows []*{orm_table_name}) *gorm.DB {
return orm.db.Create(rows)
}
func (orm *Orm{orm_table_name}) Order(value interface{}) *Orm{orm_table_name} {
orm.db.Order(value)
return orm
}
func (orm *Orm{orm_table_name}) Group(name string) *Orm{orm_table_name} {
orm.db.Group(name)
return orm
}
func (orm *Orm{orm_table_name}) Limit(limit int) *Orm{orm_table_name} {
orm.db.Limit(limit)
return orm
}
func (orm *Orm{orm_table_name}) Offset(offset int) *Orm{orm_table_name} {
orm.db.Offset(offset)
return orm
}
// Get 直接查询列表, 如果需要条数, 使用Find()
func (orm *Orm{orm_table_name}) Get() {orm_table_name}List {
got, _ := orm.Find()
return got
}
// Pluck used to query single column from a model as a map
// var ages []int64
// db.Model(&users).Pluck("age", &ages)
func (orm *Orm{orm_table_name}) Pluck(column string, dest interface{}) *gorm.DB {
return orm.db.Pluck(column, dest)
}
// Delete 有条件删除
func (orm *Orm{orm_table_name}) Delete(conds ...interface{}) *gorm.DB {
return orm.db.Delete(&{orm_table_name}{}, conds...)
}
// DeleteAll 删除所有
func (orm *Orm{orm_table_name}) DeleteAll() *gorm.DB {
return orm.db.Exec("DELETE FROM {table_name}")
}
func (orm *Orm{orm_table_name}) Count() int64 {
var count int64
orm.db.Count(&count)
return count
}
// First 检索单个对象
func (orm *Orm{orm_table_name}) First(conds ...interface{}) (*{orm_table_name}, bool) {
dest := &{orm_table_name}{}
db := orm.db.Limit(1).Find(dest, conds...)
return dest, db.RowsAffected == 1
}
// Take return a record that match given conditions, the order will depend on the database implementation
func (orm *Orm{orm_table_name}) Take(conds ...interface{}) (*{orm_table_name}, int64) {
dest := &{orm_table_name}{}
db := orm.db.Take(dest, conds...)
return dest, db.RowsAffected
}
// Last find last record that match given conditions, order by primary key
func (orm *Orm{orm_table_name}) Last(conds ...interface{}) (*{orm_table_name}, int64) {
dest := &{orm_table_name}{}
db := orm.db.Last(dest, conds...)
return dest, db.RowsAffected
}
func (orm *Orm{orm_table_name}) Find(conds ...interface{}) ({orm_table_name}List, int64) {
list := make([]*{orm_table_name}, 0)
tx := orm.db.Find(&list, conds...)
if tx.Error != nil {
logrus.Error(tx.Error)
}
return list, tx.RowsAffected
}
// Paginate 分页
func (orm *Orm{orm_table_name}) Paginate(page int, limit int) ({orm_table_name}List, int64) {
var total int64
list := make([]*{orm_table_name}, 0)
orm.db.Count(&total)
if total > 0 {
if page == 0 {
page = 1
}
offset := (page - 1) * limit
tx := orm.db.Offset(offset).Limit(limit).Find(&list)
if tx.Error != nil {
logrus.Error(tx.Error)
}
}
return list, total
}
// FindInBatches find records in batches
func (orm *Orm{orm_table_name}) FindInBatches(dest interface{}, batchSize int, fc func(tx *gorm.DB, batch int) error) *gorm.DB {
return orm.db.FindInBatches(dest, batchSize, fc)
}
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
func (orm *Orm{orm_table_name}) FirstOrInit(dest *{orm_table_name}, conds ...interface{}) (*{orm_table_name}, *gorm.DB) {
return dest, orm.db.FirstOrInit(dest, conds...)
}
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
func (orm *Orm{orm_table_name}) FirstOrCreate(dest interface{}, conds ...interface{}) *gorm.DB {
return orm.db.FirstOrCreate(dest, conds...)
}
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
func (orm *Orm{orm_table_name}) Update(column string, value interface{}) *gorm.DB {
return orm.db.Update(column, value)
}
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
func (orm *Orm{orm_table_name}) Updates(values interface{}) *gorm.DB {
return orm.db.Updates(values)
}
func (orm *Orm{orm_table_name}) UpdateColumn(column string, value interface{}) *gorm.DB {
return orm.db.UpdateColumn(column, value)
}
func (orm *Orm{orm_table_name}) UpdateColumns(values interface{}) *gorm.DB {
return orm.db.UpdateColumns(values)
}
func (orm *Orm{orm_table_name}) Where(query interface{}, args ...interface{}) *Orm{orm_table_name} {
orm.db.Where(query, args...)
return orm
}
func (orm *Orm{orm_table_name}) Select(query interface{}, args ...interface{}) *Orm{orm_table_name} {
orm.db.Select(query, args...)
return orm
}
func (orm *Orm{orm_table_name}) Sum(field string) int64 {
type result struct {
S int64 `json:"s"`
}
ret := result{}
orm.db.Select("SUM(\""+field+"\") AS s").Scan(&ret)
return ret.S
}
// Preload preload associations with given conditions
// db.Preload("Orders|orders", "state NOT IN (?)", "cancelled").Find(&users)
func (orm *Orm{orm_table_name}) Preload(query string, args ...interface{}) *Orm{orm_table_name} {
arr := strings.Split(query, ".")
for i, _ := range arr {
arr[i] = home.StringToHump(arr[i])
}
orm.db.Preload(strings.Join(arr, "."), args...)
return orm
}
// Joins specify Joins conditions
// db.Joins("Account|account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (orm *Orm{orm_table_name}) Joins(query string, args ...interface{}) *Orm{orm_table_name} {
if !strings.Contains(query, " ") {
query = home.StringToHump(query)
}
orm.db.Joins(query, args...)
return orm
}

View File

@@ -7,16 +7,28 @@ type Orm{orm_table_name} struct {
db *gorm.DB
}
func NewOrm{orm_table_name}() *Orm{orm_table_name} {
orm := &Orm{orm_table_name}{}
orm.db = providers.NewMysqlProvider().GetBean("{db}").(*gorm.DB).Model(&MysqlTableName{}).Model(&{orm_table_name}{})
return orm
func (orm *Orm{orm_table_name}) TableName() string {
return "{table_name}"
}
func NewOrm{orm_table_name}(txs ...*gorm.DB) *Orm{orm_table_name} {
var tx *gorm.DB
if len(txs) == 0 {
tx = providers.GetBean("postgresql, {connect_name}").(*gorm.DB).Model(&{orm_table_name}{})
} else {
tx = txs[0].Model(&{orm_table_name}{})
}
return &Orm{orm_table_name}{db: tx}
}
func (orm *Orm{orm_table_name}) GetDB() *gorm.DB {
return orm.db
}
func (orm *Orm{orm_table_name}) GetTableInfo() interface{} {
return &{orm_table_name}{}
}
// Create insert the value into database
func (orm *Orm{orm_table_name}) Create(value interface{}) *gorm.DB {
return orm.db.Create(value)
@@ -87,6 +99,19 @@ func (orm *Orm{orm_table_name}) Exec(sql string, values ...interface{}) *gorm.DB
return orm.db.Exec(sql, values...)
}
// Exists 检索对象是否存在
func (orm *Orm{orm_table_name}) Exists() (bool, error) {
dest := &struct {
H int `json:"h"`
}{}
db := orm.db.Select("1 as h").Limit(1).Find(dest)
return dest.H == 1, db.Error
}
func (orm *Orm{orm_table_name}) Unscoped() *Orm{orm_table_name} {
orm.db.Unscoped()
return orm
}
// ------------ 以下是单表独有的函数, 便捷字段条件, Laravel风格操作 ---------
func (orm *Orm{orm_table_name}) Insert(row *{orm_table_name}) error {
@@ -102,6 +127,11 @@ func (orm *Orm{orm_table_name}) Order(value interface{}) *Orm{orm_table_name} {
return orm
}
func (orm *Orm{orm_table_name}) Group(name string) *Orm{orm_table_name} {
orm.db.Group(name)
return orm
}
func (orm *Orm{orm_table_name}) Limit(limit int) *Orm{orm_table_name} {
orm.db.Limit(limit)
return orm
@@ -111,7 +141,7 @@ func (orm *Orm{orm_table_name}) Offset(offset int) *Orm{orm_table_name} {
orm.db.Offset(offset)
return orm
}
// 直接查询列表, 如果需要条数, 使用Find()
// Get 直接查询列表, 如果需要条数, 使用Find()
func (orm *Orm{orm_table_name}) Get() {orm_table_name}List {
got, _ := orm.Find()
return got
@@ -191,12 +221,11 @@ func (orm *Orm{orm_table_name}) Paginate(page int, limit int) ({orm_table_name}L
}
// SimplePaginate 不统计总数的分页
func (orm *Orm{orm_table_name}) Paginate(page int, limit int) {orm_table_name}List {
func (orm *Orm{orm_table_name}) SimplePaginate(page int, limit int) {orm_table_name}List {
list := make([]*{orm_table_name}, 0)
if page == 0 {
page = 1
}
offset := (page - 1) * limit
tx := orm.db.Offset(offset).Limit(limit).Find(&list)
if tx.Error != nil {
@@ -243,6 +272,20 @@ func (orm *Orm{orm_table_name}) Where(query interface{}, args ...interface{}) *O
return orm
}
func (orm *Orm{orm_table_name}) Select(query interface{}, args ...interface{}) *Orm{orm_table_name} {
orm.db.Select(query, args...)
return orm
}
func (orm *Orm{orm_table_name}) Sum(field string) int64 {
type result struct {
S int64 `json:"s"`
}
ret := result{}
orm.db.Select("SUM(\""+field+"\") AS s").Scan(&ret)
return ret.S
}
func (orm *Orm{orm_table_name}) And(fuc func(orm *Orm{orm_table_name})) *Orm{orm_table_name} {
ormAnd := NewOrm{orm_table_name}()
fuc(ormAnd)
@@ -259,8 +302,12 @@ func (orm *Orm{orm_table_name}) Or(fuc func(orm *Orm{orm_table_name})) *Orm{orm_
// Preload preload associations with given conditions
// db.Preload("Orders|orders", "state NOT IN (?)", "cancelled").Find(&users)
func (orm *OrmMysqlTableName) Preload(query string, args ...interface{}) *OrmMysqlTableName {
orm.db.Preload(home.StringToHump(query), args...)
func (orm *Orm{orm_table_name}) Preload(query string, args ...interface{}) *Orm{orm_table_name} {
arr := strings.Split(query, ".")
for i, _ := range arr {
arr[i] = home.StringToHump(arr[i])
}
orm.db.Preload(strings.Join(arr, "."), args...)
return orm
}
@@ -268,7 +315,7 @@ func (orm *OrmMysqlTableName) Preload(query string, args ...interface{}) *OrmMys
// db.Joins("Account|account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (orm *OrmMysqlTableName) Joins(query string, args ...interface{}) *OrmMysqlTableName {
func (orm *Orm{orm_table_name}) Joins(query string, args ...interface{}) *Orm{orm_table_name} {
if !strings.Contains(query, " ") {
query = home.StringToHump(query)
}