mirror of
https://github.com/go-home-admin/toolset.git
synced 2025-12-24 13:37:52 +08:00
feat: postgresql orm
This commit is contained in:
@@ -65,7 +65,7 @@ func (OrmCommand) Execute(input command.Input) {
|
||||
switch driver {
|
||||
case "mysql":
|
||||
orm.GenMysql(s.(string), conf, out)
|
||||
case "pgsql":
|
||||
case "postgresql":
|
||||
pgorm.GenSql(s.(string), conf, out)
|
||||
}
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ func (orm *OrmMysqlTableName) Offset(offset int) *OrmMysqlTableName {
|
||||
orm.db.Offset(offset)
|
||||
return orm
|
||||
}
|
||||
// 直接查询列表, 如果需要条数, 使用Find()
|
||||
// Get 直接查询列表, 如果需要条数, 使用Find()
|
||||
func (orm *OrmMysqlTableName) Get() MysqlTableNameList {
|
||||
got, _ := orm.Find()
|
||||
return got
|
||||
|
||||
@@ -141,7 +141,7 @@ func (orm *OrmMysqlTableName) Offset(offset int) *OrmMysqlTableName {
|
||||
orm.db.Offset(offset)
|
||||
return orm
|
||||
}
|
||||
// 直接查询列表, 如果需要条数, 使用Find()
|
||||
// Get 直接查询列表, 如果需要条数, 使用Find()
|
||||
func (orm *OrmMysqlTableName) Get() MysqlTableNameList {
|
||||
got, _ := orm.Find()
|
||||
return got
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
_ "github.com/lib/pq"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -30,7 +31,8 @@ func GenSql(name string, conf Conf, out string) {
|
||||
}
|
||||
|
||||
db := NewDb(conf)
|
||||
tableColumns := db.tableColumns()
|
||||
tableInfos := db.tableColumns()
|
||||
tableColumns := tableInfos.Columns
|
||||
|
||||
root, _ := os.Getwd()
|
||||
file, err := os.ReadFile(root + "/config/database/" + name + ".json")
|
||||
@@ -43,20 +45,30 @@ func GenSql(name string, conf Conf, out string) {
|
||||
}
|
||||
|
||||
// 计算import
|
||||
imports := getImports(tableColumns)
|
||||
imports := getImports(tableInfos.Infos, tableColumns)
|
||||
for table, columns := range tableColumns {
|
||||
tableName := parser.StringToSnake(table)
|
||||
file := out + "/" + tableName
|
||||
tableConfig := tableInfos.Infos[table]
|
||||
mysqlTableName := parser.StringToSnake(table)
|
||||
file := out + "/" + mysqlTableName
|
||||
|
||||
if _, err := os.Stat(file + "_lock.go"); !os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
str := "package " + name
|
||||
str += "\nimport (" + imports[table] + "\n)"
|
||||
str += "\n" + genOrmStruct(table, columns, conf, relationship[table])
|
||||
|
||||
baseFunStr := baseMysqlFuncStr
|
||||
var baseFunStr string
|
||||
if tableConfig.IsSub() {
|
||||
baseFunStr = basePgsqlSubInfoStr
|
||||
} else {
|
||||
baseFunStr = basePgsqlFuncStr
|
||||
}
|
||||
for old, newStr := range map[string]string{
|
||||
"{orm_table_name}": parser.StringToHump(table),
|
||||
"{table_name}": table,
|
||||
"{db}": name,
|
||||
"{connect_name}": name,
|
||||
} {
|
||||
baseFunStr = strings.ReplaceAll(baseFunStr, old, newStr)
|
||||
}
|
||||
@@ -75,22 +87,23 @@ func genListFunc(table string, columns []tableColumn) string {
|
||||
TableName := parser.StringToHump(table)
|
||||
str := "\ntype " + TableName + "List []*" + TableName
|
||||
for _, column := range columns {
|
||||
ColumnName := parser.StringToHump(column.ColumnName)
|
||||
// 索引,或者枚举字段
|
||||
if strInStr(column.ColumnName, []string{"id", "code"}) {
|
||||
if strInStr(column.ColumnName, []string{"id", "code"}) || strInStr(column.Comment, []string{"@index"}) {
|
||||
goType := column.GoType
|
||||
if column.IsNullable {
|
||||
goType = "*" + goType
|
||||
}
|
||||
str += "\nfunc (l " + TableName + "List) Get" + column.ColumnName + "List() []" + goType + " {" +
|
||||
str += "\nfunc (l " + TableName + "List) Get" + ColumnName + "List() []" + goType + " {" +
|
||||
"\n\tgot := make([]" + goType + ", 0)\n\tfor _, val := range l {" +
|
||||
"\n\t\tgot = append(got, val." + column.ColumnName + ")" +
|
||||
"\n\t\tgot = append(got, val." + ColumnName + ")" +
|
||||
"\n\t}" +
|
||||
"\n\treturn got" +
|
||||
"\n}"
|
||||
|
||||
str += "\nfunc (l " + TableName + "List) Get" + column.ColumnName + "Map() map[" + goType + "]*" + TableName + " {" +
|
||||
str += "\nfunc (l " + TableName + "List) Get" + ColumnName + "Map() map[" + goType + "]*" + TableName + " {" +
|
||||
"\n\tgot := make(map[" + goType + "]*" + TableName + ")\n\tfor _, val := range l {" +
|
||||
"\n\t\tgot[val." + column.ColumnName + "] = val" +
|
||||
"\n\t\tgot[val." + ColumnName + "] = val" +
|
||||
"\n\t}" +
|
||||
"\n\treturn got" +
|
||||
"\n}"
|
||||
@@ -104,73 +117,76 @@ func genFieldFunc(table string, columns []tableColumn) string {
|
||||
|
||||
str := ""
|
||||
for _, column := range columns {
|
||||
ColumnName := parser.StringToHump(column.ColumnName)
|
||||
// 等于函数
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` = ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" = ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
|
||||
if column.IsPKey {
|
||||
if strInStr(column.GoType, []string{"int32", "int64"}) {
|
||||
goType := column.GoType
|
||||
if column.IsNullable {
|
||||
goType = "*" + goType
|
||||
}
|
||||
// if 主键, 生成In, > <
|
||||
str += "\nfunc (orm *Orm" + TableName + ") InsertGet" + column.ColumnName + "(row *" + TableName + ") " + goType + " {" +
|
||||
"\n\torm.db.Create(row)" +
|
||||
"\n\treturn row." + column.ColumnName +
|
||||
"\n}"
|
||||
if column.IsPKey {
|
||||
str += "\nfunc (orm *Orm" + TableName + ") InsertGet" + ColumnName + "(row *" + TableName + ") " + goType + " {" +
|
||||
"\n\torm.db.Create(row)" +
|
||||
"\n\treturn row." + ColumnName +
|
||||
"\n}"
|
||||
}
|
||||
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` IN ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" IN ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gt(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` > ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gt(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" > ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` >= ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" >= ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lt(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` < ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lt(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" < ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` <= ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <= ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
} else {
|
||||
// 索引,或者枚举字段
|
||||
if strInStr(column.ColumnName, []string{"id", "code", "status", "state"}) {
|
||||
// else if 名称存在 id, code, status 生成in操作
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` IN ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" IN ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Ne(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` <> ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Ne(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <> ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
}
|
||||
// 时间字段
|
||||
if strInStr(column.ColumnName, []string{"created", "updated", "time", "_at"}) || (column.GoType == "database.Time") {
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Between(begin " + column.GoType + ", end " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` BETWEEN ? AND ?\", begin, end)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Between(begin " + column.GoType + ", end " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" BETWEEN ? AND ?\", begin, end)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` <= ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" <= ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"`" + column.ColumnName + "` >= ?\", val)" +
|
||||
str += "\nfunc (orm *Orm" + TableName + ") Where" + ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" +
|
||||
"\n\torm.db.Where(\"\\\"" + column.ColumnName + "\\\" >= ?\", val)" +
|
||||
"\n\treturn orm" +
|
||||
"\n}"
|
||||
}
|
||||
@@ -189,17 +205,19 @@ func strInStr(s string, in []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
//go:embed pgsql.go.subtext
|
||||
var basePgsqlSubInfoStr string
|
||||
|
||||
//go:embed pgsql.go.text
|
||||
var baseMysqlFuncStr string
|
||||
var basePgsqlFuncStr string
|
||||
|
||||
// 字段类型引入
|
||||
var alias = map[string]string{
|
||||
"database": "github.com/go-home-admin/home/bootstrap/services/database",
|
||||
"datatypes": "gorm.io/datatypes",
|
||||
"database": "github.com/go-home-admin/home/bootstrap/services/database",
|
||||
}
|
||||
|
||||
// 获得 table => map{alias => github.com/*}
|
||||
func getImports(tableColumns map[string][]tableColumn) map[string]string {
|
||||
func getImports(infos map[string]orm.TableInfos, tableColumns map[string][]tableColumn) map[string]string {
|
||||
got := make(map[string]string)
|
||||
for table, columns := range tableColumns {
|
||||
// 初始引入
|
||||
@@ -211,10 +229,14 @@ func getImports(tableColumns map[string][]tableColumn) map[string]string {
|
||||
"database/sql": "sql",
|
||||
"github.com/go-home-admin/home/app": "home",
|
||||
}
|
||||
if infos[table].IsSub() {
|
||||
delete(tm, "github.com/go-home-admin/home/bootstrap/providers")
|
||||
}
|
||||
|
||||
for _, column := range columns {
|
||||
index := strings.Index(column.GoType, ".")
|
||||
if index != -1 && column.GoType[:index] != "gorm" {
|
||||
as := strings.Replace(column.GoType[:index], "*", "", 1)
|
||||
if index != -1 {
|
||||
as := column.GoType[:index]
|
||||
importStr := alias[as]
|
||||
tm[importStr] = as
|
||||
}
|
||||
@@ -232,12 +254,20 @@ func genOrmStruct(table string, columns []tableColumn, conf Conf, relationships
|
||||
str := `type {TableName} struct {`
|
||||
for _, column := range columns {
|
||||
p := ""
|
||||
if column.IsNullable && column.ColumnName != "deleted_at" {
|
||||
if column.IsNullable && !(column.ColumnName == "deleted_at" && column.GoType == "database.Time") {
|
||||
p = "*"
|
||||
}
|
||||
if column.ColumnName == "deleted_at" {
|
||||
if column.ColumnName == "deleted_at" && column.GoType == "database.Time" {
|
||||
column.GoType = "gorm.DeletedAt"
|
||||
}
|
||||
|
||||
// 使用注释@Type(int), 强制设置生成的go struct 属性 类型
|
||||
if i := strings.Index(column.ColumnName, "@type("); i != -1 {
|
||||
s := column.Comment[i+6:]
|
||||
e := strings.Index(s, ")")
|
||||
column.GoType = s[:e]
|
||||
}
|
||||
|
||||
hasField[column.ColumnName] = true
|
||||
fieldName := parser.StringToHump(column.ColumnName)
|
||||
str += fmt.Sprintf("\n\t%v %v%v`%v` // %v", fieldName, p, column.GoType, genGormTag(column), strings.ReplaceAll(column.Comment, "\n", " "))
|
||||
@@ -310,6 +340,8 @@ func genGormTag(column tableColumn) string {
|
||||
// 主键
|
||||
if column.IsPKey {
|
||||
arr = append(arr, "primaryKey")
|
||||
} else if column.IndexName != "" {
|
||||
arr = append(arr, "index:"+column.ColumnName)
|
||||
}
|
||||
// default
|
||||
if column.ColumnDefault != "" {
|
||||
@@ -335,13 +367,15 @@ func (d *DB) GetDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d *DB) tableColumns() map[string][]tableColumn {
|
||||
// 获取所有表信息
|
||||
// 过滤分表信息, table_{1-9} 只返回table
|
||||
func (d *DB) tableColumns() TableInfo {
|
||||
var sqlStr = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
|
||||
|
||||
rows, err := d.db.Query(sqlStr)
|
||||
if err != nil {
|
||||
log.Println("Error reading table information: ", err.Error())
|
||||
return nil
|
||||
return TableInfo{}
|
||||
}
|
||||
defer rows.Close()
|
||||
ormColumns := make(map[string][]tableColumn)
|
||||
@@ -352,10 +386,12 @@ func (d *DB) tableColumns() map[string][]tableColumn {
|
||||
&tableName,
|
||||
)
|
||||
_rows, _ := d.db.Query(`
|
||||
SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment
|
||||
SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment, ixc.relname
|
||||
FROM information_schema.columns as i
|
||||
LEFT JOIN pg_class as c on c.relname = i.table_name
|
||||
LEFT JOIN pg_attribute as a on a.attrelid = c.oid and a.attname = i.column_name
|
||||
LEFT JOIN pg_index ix ON c.oid = ix.indrelid AND a.attnum = ANY(ix.indkey)
|
||||
LEFT JOIN pg_class ixc ON ixc.oid = ix.indexrelid
|
||||
WHERE table_schema = 'public' and i.table_name = $1;
|
||||
`, tableName)
|
||||
defer _rows.Close()
|
||||
@@ -380,32 +416,74 @@ WHERE pg_class.relname = $1 AND pg_constraint.contype = 'p'
|
||||
is_nullable string
|
||||
udt_name string
|
||||
comment *string
|
||||
index_name *string
|
||||
)
|
||||
err = _rows.Scan(&column_name, &column_default, &is_nullable, &udt_name, &comment)
|
||||
err = _rows.Scan(&column_name, &column_default, &is_nullable, &udt_name, &comment, &index_name)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var columnComment string
|
||||
var columnComment, indexName string
|
||||
if comment != nil {
|
||||
columnComment = *comment
|
||||
}
|
||||
if index_name != nil {
|
||||
indexName = *index_name
|
||||
}
|
||||
var ColumnDefault string
|
||||
if column_default != nil {
|
||||
ColumnDefault = *column_default
|
||||
}
|
||||
|
||||
ormColumns[tableName] = append(ormColumns[tableName], tableColumn{
|
||||
ColumnName: parser.StringToHump(column_name),
|
||||
ColumnName: column_name,
|
||||
ColumnDefault: ColumnDefault,
|
||||
PgType: udt_name,
|
||||
GoType: PgTypeToGoType(udt_name, column_name),
|
||||
IsNullable: is_nullable == "YES",
|
||||
IsPKey: false,
|
||||
IsPKey: column_name == pkey,
|
||||
Comment: columnComment,
|
||||
IndexName: indexName,
|
||||
})
|
||||
}
|
||||
}
|
||||
return ormColumns
|
||||
return Filter(ormColumns)
|
||||
}
|
||||
|
||||
// Filter 过滤分表格式
|
||||
// table_{0-9} 只返回table
|
||||
func Filter(tableColumns map[string][]tableColumn) TableInfo {
|
||||
info := TableInfo{
|
||||
Columns: make(map[string][]tableColumn),
|
||||
Infos: make(map[string]orm.TableInfos),
|
||||
}
|
||||
tableSort := make(map[string]int)
|
||||
for tableName, columns := range tableColumns {
|
||||
arr := strings.Split(tableName, "_")
|
||||
arrLen := len(arr)
|
||||
if arrLen > 1 {
|
||||
str := arr[arrLen-1]
|
||||
tn, err := strconv.Atoi(str)
|
||||
if err == nil {
|
||||
tableName = strings.ReplaceAll(tableName, "_"+str, "")
|
||||
info.Infos[tableName] = orm.TableInfos{
|
||||
"sub": "true", // 分表
|
||||
}
|
||||
// 保留数字最大的
|
||||
n, ok := tableSort[tableName]
|
||||
if ok && n > tn {
|
||||
continue
|
||||
}
|
||||
tableSort[tableName] = tn
|
||||
}
|
||||
}
|
||||
info.Columns[tableName] = columns
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
type TableInfo struct {
|
||||
Columns map[string][]tableColumn
|
||||
Infos map[string]orm.TableInfos
|
||||
}
|
||||
|
||||
type tableColumn struct {
|
||||
@@ -417,6 +495,7 @@ type tableColumn struct {
|
||||
IsNullable bool
|
||||
IsPKey bool
|
||||
Comment string
|
||||
IndexName string
|
||||
}
|
||||
|
||||
func PgTypeToGoType(pgType string, columnName string) string {
|
||||
|
||||
278
console/commands/pgorm/pgsql.go.subtext
Normal file
278
console/commands/pgorm/pgsql.go.subtext
Normal file
@@ -0,0 +1,278 @@
|
||||
|
||||
type Orm{orm_table_name} struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) GetDB() *gorm.DB {
|
||||
return orm.db
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) GetTableInfo() interface{} {
|
||||
return &{orm_table_name}{}
|
||||
}
|
||||
|
||||
// Create insert the value into database
|
||||
func (orm *Orm{orm_table_name}) Create(value interface{}) *gorm.DB {
|
||||
return orm.db.Create(value)
|
||||
}
|
||||
|
||||
// CreateInBatches insert the value in batches into database
|
||||
func (orm *Orm{orm_table_name}) CreateInBatches(value interface{}, batchSize int) *gorm.DB {
|
||||
return orm.db.CreateInBatches(value, batchSize)
|
||||
}
|
||||
|
||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||
func (orm *Orm{orm_table_name}) Save(value interface{}) *gorm.DB {
|
||||
return orm.db.Save(value)
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Row() *sql.Row {
|
||||
return orm.db.Row()
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Rows() (*sql.Rows, error) {
|
||||
return orm.db.Rows()
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (orm *Orm{orm_table_name}) Scan(dest interface{}) *gorm.DB {
|
||||
return orm.db.Scan(dest)
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||
return orm.db.ScanRows(rows, dest)
|
||||
}
|
||||
|
||||
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
|
||||
func (orm *Orm{orm_table_name}) Connection(fc func(tx *gorm.DB) error) (err error) {
|
||||
return orm.db.Connection(fc)
|
||||
}
|
||||
|
||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
||||
func (orm *Orm{orm_table_name}) Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error) {
|
||||
return orm.db.Transaction(fc, opts...)
|
||||
}
|
||||
|
||||
// Begin begins a transaction
|
||||
func (orm *Orm{orm_table_name}) Begin(opts ...*sql.TxOptions) *gorm.DB {
|
||||
return orm.db.Begin(opts...)
|
||||
}
|
||||
|
||||
// Commit commit a transaction
|
||||
func (orm *Orm{orm_table_name}) Commit() *gorm.DB {
|
||||
return orm.db.Commit()
|
||||
}
|
||||
|
||||
// Rollback rollback a transaction
|
||||
func (orm *Orm{orm_table_name}) Rollback() *gorm.DB {
|
||||
return orm.db.Rollback()
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) SavePoint(name string) *gorm.DB {
|
||||
return orm.db.SavePoint(name)
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) RollbackTo(name string) *gorm.DB {
|
||||
return orm.db.RollbackTo(name)
|
||||
}
|
||||
|
||||
// Exec execute raw sql
|
||||
func (orm *Orm{orm_table_name}) Exec(sql string, values ...interface{}) *gorm.DB {
|
||||
return orm.db.Exec(sql, values...)
|
||||
}
|
||||
|
||||
// Exists 检索对象是否存在
|
||||
func (orm *Orm{orm_table_name}) Exists() (bool, error) {
|
||||
dest := &struct {
|
||||
H int `json:"h"`
|
||||
}{}
|
||||
db := orm.db.Select("1 as h").Limit(1).Find(dest)
|
||||
return dest.H == 1, db.Error
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Unscoped() *Orm{orm_table_name} {
|
||||
orm.db.Unscoped()
|
||||
return orm
|
||||
}
|
||||
// ------------ 以下是单表独有的函数, 便捷字段条件, Laravel风格操作 ---------
|
||||
|
||||
func (orm *Orm{orm_table_name}) Insert(row *{orm_table_name}) error {
|
||||
return orm.db.Create(row).Error
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Inserts(rows []*{orm_table_name}) *gorm.DB {
|
||||
return orm.db.Create(rows)
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Order(value interface{}) *Orm{orm_table_name} {
|
||||
orm.db.Order(value)
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Group(name string) *Orm{orm_table_name} {
|
||||
orm.db.Group(name)
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Limit(limit int) *Orm{orm_table_name} {
|
||||
orm.db.Limit(limit)
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Offset(offset int) *Orm{orm_table_name} {
|
||||
orm.db.Offset(offset)
|
||||
return orm
|
||||
}
|
||||
// Get 直接查询列表, 如果需要条数, 使用Find()
|
||||
func (orm *Orm{orm_table_name}) Get() {orm_table_name}List {
|
||||
got, _ := orm.Find()
|
||||
return got
|
||||
}
|
||||
|
||||
// Pluck used to query single column from a model as a map
|
||||
// var ages []int64
|
||||
// db.Model(&users).Pluck("age", &ages)
|
||||
func (orm *Orm{orm_table_name}) Pluck(column string, dest interface{}) *gorm.DB {
|
||||
return orm.db.Pluck(column, dest)
|
||||
}
|
||||
|
||||
// Delete 有条件删除
|
||||
func (orm *Orm{orm_table_name}) Delete(conds ...interface{}) *gorm.DB {
|
||||
return orm.db.Delete(&{orm_table_name}{}, conds...)
|
||||
}
|
||||
|
||||
// DeleteAll 删除所有
|
||||
func (orm *Orm{orm_table_name}) DeleteAll() *gorm.DB {
|
||||
return orm.db.Exec("DELETE FROM {table_name}")
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Count() int64 {
|
||||
var count int64
|
||||
orm.db.Count(&count)
|
||||
return count
|
||||
}
|
||||
|
||||
// First 检索单个对象
|
||||
func (orm *Orm{orm_table_name}) First(conds ...interface{}) (*{orm_table_name}, bool) {
|
||||
dest := &{orm_table_name}{}
|
||||
db := orm.db.Limit(1).Find(dest, conds...)
|
||||
return dest, db.RowsAffected == 1
|
||||
}
|
||||
|
||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||
func (orm *Orm{orm_table_name}) Take(conds ...interface{}) (*{orm_table_name}, int64) {
|
||||
dest := &{orm_table_name}{}
|
||||
db := orm.db.Take(dest, conds...)
|
||||
return dest, db.RowsAffected
|
||||
}
|
||||
|
||||
// Last find last record that match given conditions, order by primary key
|
||||
func (orm *Orm{orm_table_name}) Last(conds ...interface{}) (*{orm_table_name}, int64) {
|
||||
dest := &{orm_table_name}{}
|
||||
db := orm.db.Last(dest, conds...)
|
||||
return dest, db.RowsAffected
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Find(conds ...interface{}) ({orm_table_name}List, int64) {
|
||||
list := make([]*{orm_table_name}, 0)
|
||||
tx := orm.db.Find(&list, conds...)
|
||||
if tx.Error != nil {
|
||||
logrus.Error(tx.Error)
|
||||
}
|
||||
return list, tx.RowsAffected
|
||||
}
|
||||
|
||||
// Paginate 分页
|
||||
func (orm *Orm{orm_table_name}) Paginate(page int, limit int) ({orm_table_name}List, int64) {
|
||||
var total int64
|
||||
list := make([]*{orm_table_name}, 0)
|
||||
orm.db.Count(&total)
|
||||
if total > 0 {
|
||||
if page == 0 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
offset := (page - 1) * limit
|
||||
tx := orm.db.Offset(offset).Limit(limit).Find(&list)
|
||||
if tx.Error != nil {
|
||||
logrus.Error(tx.Error)
|
||||
}
|
||||
}
|
||||
|
||||
return list, total
|
||||
}
|
||||
|
||||
// FindInBatches find records in batches
|
||||
func (orm *Orm{orm_table_name}) FindInBatches(dest interface{}, batchSize int, fc func(tx *gorm.DB, batch int) error) *gorm.DB {
|
||||
return orm.db.FindInBatches(dest, batchSize, fc)
|
||||
}
|
||||
|
||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
||||
func (orm *Orm{orm_table_name}) FirstOrInit(dest *{orm_table_name}, conds ...interface{}) (*{orm_table_name}, *gorm.DB) {
|
||||
return dest, orm.db.FirstOrInit(dest, conds...)
|
||||
}
|
||||
|
||||
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
|
||||
func (orm *Orm{orm_table_name}) FirstOrCreate(dest interface{}, conds ...interface{}) *gorm.DB {
|
||||
return orm.db.FirstOrCreate(dest, conds...)
|
||||
}
|
||||
|
||||
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (orm *Orm{orm_table_name}) Update(column string, value interface{}) *gorm.DB {
|
||||
return orm.db.Update(column, value)
|
||||
}
|
||||
|
||||
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (orm *Orm{orm_table_name}) Updates(values interface{}) *gorm.DB {
|
||||
return orm.db.Updates(values)
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) UpdateColumn(column string, value interface{}) *gorm.DB {
|
||||
return orm.db.UpdateColumn(column, value)
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) UpdateColumns(values interface{}) *gorm.DB {
|
||||
return orm.db.UpdateColumns(values)
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Where(query interface{}, args ...interface{}) *Orm{orm_table_name} {
|
||||
orm.db.Where(query, args...)
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Select(query interface{}, args ...interface{}) *Orm{orm_table_name} {
|
||||
orm.db.Select(query, args...)
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Sum(field string) int64 {
|
||||
type result struct {
|
||||
S int64 `json:"s"`
|
||||
}
|
||||
ret := result{}
|
||||
orm.db.Select("SUM(\""+field+"\") AS s").Scan(&ret)
|
||||
return ret.S
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
// db.Preload("Orders|orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (orm *Orm{orm_table_name}) Preload(query string, args ...interface{}) *Orm{orm_table_name} {
|
||||
arr := strings.Split(query, ".")
|
||||
for i, _ := range arr {
|
||||
arr[i] = home.StringToHump(arr[i])
|
||||
}
|
||||
orm.db.Preload(strings.Join(arr, "."), args...)
|
||||
return orm
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
// db.Joins("Account|account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
func (orm *Orm{orm_table_name}) Joins(query string, args ...interface{}) *Orm{orm_table_name} {
|
||||
if !strings.Contains(query, " ") {
|
||||
query = home.StringToHump(query)
|
||||
}
|
||||
orm.db.Joins(query, args...)
|
||||
return orm
|
||||
}
|
||||
@@ -7,16 +7,28 @@ type Orm{orm_table_name} struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewOrm{orm_table_name}() *Orm{orm_table_name} {
|
||||
orm := &Orm{orm_table_name}{}
|
||||
orm.db = providers.NewMysqlProvider().GetBean("{db}").(*gorm.DB).Model(&MysqlTableName{}).Model(&{orm_table_name}{})
|
||||
return orm
|
||||
func (orm *Orm{orm_table_name}) TableName() string {
|
||||
return "{table_name}"
|
||||
}
|
||||
|
||||
func NewOrm{orm_table_name}(txs ...*gorm.DB) *Orm{orm_table_name} {
|
||||
var tx *gorm.DB
|
||||
if len(txs) == 0 {
|
||||
tx = providers.GetBean("postgresql, {connect_name}").(*gorm.DB).Model(&{orm_table_name}{})
|
||||
} else {
|
||||
tx = txs[0].Model(&{orm_table_name}{})
|
||||
}
|
||||
return &Orm{orm_table_name}{db: tx}
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) GetDB() *gorm.DB {
|
||||
return orm.db
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) GetTableInfo() interface{} {
|
||||
return &{orm_table_name}{}
|
||||
}
|
||||
|
||||
// Create insert the value into database
|
||||
func (orm *Orm{orm_table_name}) Create(value interface{}) *gorm.DB {
|
||||
return orm.db.Create(value)
|
||||
@@ -87,6 +99,19 @@ func (orm *Orm{orm_table_name}) Exec(sql string, values ...interface{}) *gorm.DB
|
||||
return orm.db.Exec(sql, values...)
|
||||
}
|
||||
|
||||
// Exists 检索对象是否存在
|
||||
func (orm *Orm{orm_table_name}) Exists() (bool, error) {
|
||||
dest := &struct {
|
||||
H int `json:"h"`
|
||||
}{}
|
||||
db := orm.db.Select("1 as h").Limit(1).Find(dest)
|
||||
return dest.H == 1, db.Error
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Unscoped() *Orm{orm_table_name} {
|
||||
orm.db.Unscoped()
|
||||
return orm
|
||||
}
|
||||
// ------------ 以下是单表独有的函数, 便捷字段条件, Laravel风格操作 ---------
|
||||
|
||||
func (orm *Orm{orm_table_name}) Insert(row *{orm_table_name}) error {
|
||||
@@ -102,6 +127,11 @@ func (orm *Orm{orm_table_name}) Order(value interface{}) *Orm{orm_table_name} {
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Group(name string) *Orm{orm_table_name} {
|
||||
orm.db.Group(name)
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Limit(limit int) *Orm{orm_table_name} {
|
||||
orm.db.Limit(limit)
|
||||
return orm
|
||||
@@ -111,7 +141,7 @@ func (orm *Orm{orm_table_name}) Offset(offset int) *Orm{orm_table_name} {
|
||||
orm.db.Offset(offset)
|
||||
return orm
|
||||
}
|
||||
// 直接查询列表, 如果需要条数, 使用Find()
|
||||
// Get 直接查询列表, 如果需要条数, 使用Find()
|
||||
func (orm *Orm{orm_table_name}) Get() {orm_table_name}List {
|
||||
got, _ := orm.Find()
|
||||
return got
|
||||
@@ -191,12 +221,11 @@ func (orm *Orm{orm_table_name}) Paginate(page int, limit int) ({orm_table_name}L
|
||||
}
|
||||
|
||||
// SimplePaginate 不统计总数的分页
|
||||
func (orm *Orm{orm_table_name}) Paginate(page int, limit int) {orm_table_name}List {
|
||||
func (orm *Orm{orm_table_name}) SimplePaginate(page int, limit int) {orm_table_name}List {
|
||||
list := make([]*{orm_table_name}, 0)
|
||||
if page == 0 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
offset := (page - 1) * limit
|
||||
tx := orm.db.Offset(offset).Limit(limit).Find(&list)
|
||||
if tx.Error != nil {
|
||||
@@ -243,6 +272,20 @@ func (orm *Orm{orm_table_name}) Where(query interface{}, args ...interface{}) *O
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Select(query interface{}, args ...interface{}) *Orm{orm_table_name} {
|
||||
orm.db.Select(query, args...)
|
||||
return orm
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) Sum(field string) int64 {
|
||||
type result struct {
|
||||
S int64 `json:"s"`
|
||||
}
|
||||
ret := result{}
|
||||
orm.db.Select("SUM(\""+field+"\") AS s").Scan(&ret)
|
||||
return ret.S
|
||||
}
|
||||
|
||||
func (orm *Orm{orm_table_name}) And(fuc func(orm *Orm{orm_table_name})) *Orm{orm_table_name} {
|
||||
ormAnd := NewOrm{orm_table_name}()
|
||||
fuc(ormAnd)
|
||||
@@ -259,8 +302,12 @@ func (orm *Orm{orm_table_name}) Or(fuc func(orm *Orm{orm_table_name})) *Orm{orm_
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
// db.Preload("Orders|orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (orm *OrmMysqlTableName) Preload(query string, args ...interface{}) *OrmMysqlTableName {
|
||||
orm.db.Preload(home.StringToHump(query), args...)
|
||||
func (orm *Orm{orm_table_name}) Preload(query string, args ...interface{}) *Orm{orm_table_name} {
|
||||
arr := strings.Split(query, ".")
|
||||
for i, _ := range arr {
|
||||
arr[i] = home.StringToHump(arr[i])
|
||||
}
|
||||
orm.db.Preload(strings.Join(arr, "."), args...)
|
||||
return orm
|
||||
}
|
||||
|
||||
@@ -268,7 +315,7 @@ func (orm *OrmMysqlTableName) Preload(query string, args ...interface{}) *OrmMys
|
||||
// db.Joins("Account|account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
func (orm *OrmMysqlTableName) Joins(query string, args ...interface{}) *OrmMysqlTableName {
|
||||
func (orm *Orm{orm_table_name}) Joins(query string, args ...interface{}) *Orm{orm_table_name} {
|
||||
if !strings.Contains(query, " ") {
|
||||
query = home.StringToHump(query)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user