From 044f9c7543661a8307c9f94155fb3dab7316c768 Mon Sep 17 00:00:00 2001 From: aliang Date: Mon, 6 Jun 2022 17:25:17 +0800 Subject: [PATCH] pgsql --- console/commands/orm.go | 11 +- console/commands/pgorm/pgsql.go | 448 +++++++++++++++++++++++++++ console/commands/pgorm/pgsql.go.text | 223 +++++++++++++ go.mod | 1 + go.sum | 2 + 5 files changed, 680 insertions(+), 5 deletions(-) create mode 100644 console/commands/pgorm/pgsql.go create mode 100644 console/commands/pgorm/pgsql.go.text diff --git a/console/commands/orm.go b/console/commands/orm.go index a756b0e..083fa22 100644 --- a/console/commands/orm.go +++ b/console/commands/orm.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/ctfang/command" "github.com/go-home-admin/toolset/console/commands/orm" + "github.com/go-home-admin/toolset/console/commands/pgorm" "github.com/joho/godotenv" "gopkg.in/yaml.v2" "os" @@ -51,9 +52,9 @@ func (OrmCommand) Execute(input command.Input) { fileContext = SetEnv(fileContext) m := make(map[string]interface{}) err = yaml.Unmarshal(fileContext, &m) - if err != nil { - panic(err) - } + //if err != nil { + // panic(err) + //} connections := m["connections"].(map[interface{}]interface{}) for s, confT := range connections { @@ -63,8 +64,8 @@ func (OrmCommand) Execute(input command.Input) { switch driver { case "mysql": orm.GenMysql(s.(string), conf, out) - case "postgresql": - + case "pgsql": + pgorm.GenSql(s.(string), conf, out) } cmd := exec.Command("go", []string{"fmt", out}...) diff --git a/console/commands/pgorm/pgsql.go b/console/commands/pgorm/pgsql.go new file mode 100644 index 0000000..ea22880 --- /dev/null +++ b/console/commands/pgorm/pgsql.go @@ -0,0 +1,448 @@ +package pgorm + +import ( + "database/sql" + _ "embed" + "fmt" + "github.com/go-home-admin/home/bootstrap/services" + "github.com/go-home-admin/toolset/parser" + _ "github.com/lib/pq" + "log" + "os" + "strings" + "time" +) + +// IsExist checks whether a file or directory exists. +// It returns false when the file or directory does not exist. +func IsExist(f string) bool { + _, err := os.Stat(f) + return err == nil || os.IsExist(err) +} + +type Conf map[interface{}]interface{} + +func GenSql(name string, conf Conf, out string) { + if !IsExist(out) { + os.MkdirAll(out, 0766) + } + + db := NewDb(conf) + tableColumns := db.tableColumns() + + // 计算import + imports := getImports(tableColumns) + for table, columns := range tableColumns { + tableName := parser.StringToSnake(table) + file := out + "/" + tableName + + str := "package " + name + str += "\nimport (" + imports[table] + "\n)" + str += "\n" + genOrmStruct(table, columns, conf) + + baseFunStr := baseMysqlFuncStr + for old, newStr := range map[string]string{ + "{orm_table_name}": parser.StringToHump(table), + "{table_name}": table, + "{db}": name, + } { + baseFunStr = strings.ReplaceAll(baseFunStr, old, newStr) + } + + str += baseFunStr + str += genFieldFunc(table, columns) + str += genListFunc(table, columns) + str += genWithFunc(table, columns, conf) + err := os.WriteFile(file+"_gen.go", []byte(str), 0766) + if err != nil { + log.Fatal(err) + } + } +} + +func genListFunc(table string, columns []tableColumn) string { + TableName := parser.StringToHump(table) + str := "\ntype " + TableName + "List []*" + TableName + for _, column := range columns { + // 索引,或者枚举字段 + if strInStr(column.ColumnName, []string{"id", "code"}) { + str += "\nfunc (l " + TableName + "List) Get" + column.ColumnName + "List() []" + column.GoType + " {" + + "\n\tgot := make([]" + column.GoType + ", 0)\n\tfor _, val := range l {" + + "\n\t\tgot = append(got, val." + column.ColumnName + ")" + + "\n\t}" + + "\n\treturn got" + + "\n}" + + str += "\nfunc (l " + TableName + "List) Get" + column.ColumnName + "Map() map[" + column.GoType + "]*" + TableName + " {" + + "\n\tgot := make(map[" + column.GoType + "]*" + TableName + ")\n\tfor _, val := range l {" + + "\n\t\tgot[val." + column.ColumnName + "] = val" + + "\n\t}" + + "\n\treturn got" + + "\n}" + } + } + return str +} + +func genWithFunc(table string, columns []tableColumn, conf Conf) string { + TableName := parser.StringToHump(table) + str := "" + if helper, ok := conf["helper"]; ok { + helperConf := helper.(map[interface{}]interface{}) + tableConfig, ok := helperConf[table].([]interface{}) + if ok { + for _, c := range tableConfig { + cf := c.(map[interface{}]interface{}) + with := cf["with"] + tbName := parser.StringToHump(cf["table"].(string)) + switch with { + case "many2many": + + default: + str += "\nfunc (orm *Orm" + TableName + ") Joins" + tbName + "(args ...interface{}) *Orm" + TableName + " {" + + "\n\torm.db.Joins(\"" + cf["alias"].(string) + "\", args...)" + + "\n\treturn orm" + + "\n}" + str += "\nfunc (orm *Orm" + TableName + ") Preload" + tbName + "(args ...interface{}) *Orm" + TableName + " {" + + "\n\torm.db.Preload(\"" + cf["alias"].(string) + "\", args...)" + + "\n\treturn orm" + + "\n}" + } + } + } + } + return str +} + +func genFieldFunc(table string, columns []tableColumn) string { + TableName := parser.StringToHump(table) + + str := "" + for _, column := range columns { + // 等于函数 + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` = ?\", val)" + + "\n\treturn orm" + + "\n}" + + if column.IsPKey { + // if 主键, 生成In, > < + str += "\nfunc (orm *Orm" + TableName + ") InsertGet" + column.ColumnName + "(row *" + TableName + ") " + column.GoType + " {" + + "\n\torm.db.Create(row)" + + "\n\treturn row." + column.ColumnName + + "\n}" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` IN ?\", val)" + + "\n\treturn orm" + + "\n}" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gt(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` > ?\", val)" + + "\n\treturn orm" + + "\n}" + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` >= ?\", val)" + + "\n\treturn orm" + + "\n}" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lt(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` < ?\", val)" + + "\n\treturn orm" + + "\n}" + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` <= ?\", val)" + + "\n\treturn orm" + + "\n}" + } else { + // 索引,或者枚举字段 + if strInStr(column.ColumnName, []string{"id", "code", "status", "state"}) { + // else if 名称存在 id, code, status 生成in操作 + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "In(val []" + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` IN ?\", val)" + + "\n\treturn orm" + + "\n}" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Ne(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` <> ?\", val)" + + "\n\treturn orm" + + "\n}" + } + // 时间字段 + if strInStr(column.ColumnName, []string{"created", "updated", "time", "_at"}) || (column.GoType == "database.Time") { + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Between(begin " + column.GoType + ", end " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` BETWEEN ? AND ?\", begin, end)" + + "\n\treturn orm" + + "\n}" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Lte(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` <= ?\", val)" + + "\n\treturn orm" + + "\n}" + + str += "\nfunc (orm *Orm" + TableName + ") Where" + column.ColumnName + "Gte(val " + column.GoType + ") *Orm" + TableName + " {" + + "\n\torm.db.Where(\"`" + column.ColumnName + "` >= ?\", val)" + + "\n\treturn orm" + + "\n}" + } + } + } + + return str +} + +func strInStr(s string, in []string) bool { + for _, sub := range in { + if strings.Index(s, sub) != -1 { + return true + } + } + return false +} + +//go:embed pgsql.go.text +var baseMysqlFuncStr string + +// 字段类型引入 +var alias = map[string]string{ + "database": "github.com/go-home-admin/home/bootstrap/services/database", + "datatypes": "gorm.io/datatypes", +} + +// 获得 table => map{alias => github.com/*} +func getImports(tableColumns map[string][]tableColumn) map[string]string { + got := make(map[string]string) + for table, columns := range tableColumns { + // 初始引入 + tm := map[string]string{ + "gorm.io/gorm": "gorm", + "github.com/go-home-admin/home/bootstrap/providers": "providers", + "github.com/sirupsen/logrus": "logrus", + "database/sql": "sql", + } + for _, column := range columns { + index := strings.Index(column.GoType, ".") + if index != -1 { + as := strings.Replace(column.GoType[:index], "*", "", 1) + importStr := alias[as] + tm[importStr] = as + } + } + got[table] = parser.GetImportStrForMap(tm) + } + + return got +} + +func genOrmStruct(table string, columns []tableColumn, conf Conf) string { + TableName := parser.StringToHump(table) + + hasField := make(map[string]bool) + str := `type {TableName} struct {` + for _, column := range columns { + p := "" + if column.IsNullable { + p = "*" + } + hasField[column.ColumnName] = true + fieldName := parser.StringToHump(column.ColumnName) + str += fmt.Sprintf("\n\t%v %v%v`%v` // %v", fieldName, p, column.GoType, genGormTag(column), strings.ReplaceAll(column.Comment, "\n", " ")) + } + // 表依赖 + if helper, ok := conf["helper"]; ok { + helperConf := helper.(map[interface{}]interface{}) + tableConfig, ok := helperConf[table].([]interface{}) + if ok { + for _, c := range tableConfig { + cf := c.(map[interface{}]interface{}) + with := cf["with"] + tbName := parser.StringToHump(cf["table"].(string)) + switch with { + case "belongs_to": + str += fmt.Sprintf("\n\t%v %v `gorm:\"%v\"`", parser.StringToHump(cf["alias"].(string)), tbName, cf["gorm"]) + case "has_one": + str += fmt.Sprintf("\n\t%v %v `gorm:\"%v\"`", parser.StringToHump(cf["alias"].(string)), tbName, cf["gorm"]) + case "has_many": + str += fmt.Sprintf("\n\t%v []%v `gorm:\"%v\"`", parser.StringToHump(cf["alias"].(string)), tbName, cf["gorm"]) + case "many2many": + str += fmt.Sprintf("\n\t%v []%v `gorm:\"%v\"`", parser.StringToHump(cf["alias"].(string)), tbName, cf["gorm"]) + default: + panic("with: belongs_to,has_one,has_many,many2many") + } + } + } + } + + str = strings.ReplaceAll(str, "{TableName}", TableName) + return "\n" + str + "\n}" +} + +func genGormTag(column tableColumn) string { + var arr []string + // 字段 + arr = append(arr, "column:"+column.ColumnName) + if column.ColumnDefault == "CURRENT_TIMESTAMP" { + arr = append(arr, "autoUpdateTime") + } + if strings.Contains(column.ColumnDefault, "nextval") { + arr = append(arr, "autoIncrement") + } + // 类型ing + arr = append(arr, "type:"+column.PgType) + // 主键 + if column.IsPKey { + arr = append(arr, "primaryKey") + } + // default + if column.ColumnDefault != "" { + arr = append(arr, "default:"+column.ColumnDefault) + } + + if column.Comment != "" { + arr = append(arr, fmt.Sprintf("comment:'%v'", strings.ReplaceAll(column.Comment, "'", ""))) + } + str := "" + for i := 0; i < len(arr)-1; i++ { + str += arr[i] + ";" + } + str += "" + arr[len(arr)-1] + return "gorm:\"" + str + "\"" +} + +type DB struct { + db *sql.DB +} + +func (d *DB) tableColumns() map[string][]tableColumn { + var sqlStr = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'" + + rows, err := d.db.Query(sqlStr) + if err != nil { + log.Println("Error reading table information: ", err.Error()) + return nil + } + defer rows.Close() + ormColumns := make(map[string][]tableColumn) + for rows.Next() { + var tableName string + var pkey string + _ = rows.Scan( + &tableName, + ) + _rows, _ := d.db.Query(` +SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment +FROM information_schema.columns as i +LEFT JOIN pg_class as c on c.relname = i.table_name +LEFT JOIN pg_attribute as a on a.attrelid = c.oid and a.attname = i.column_name +WHERE table_schema = 'public' and i.table_name = $1; + `, tableName) + defer _rows.Close() + //获取主键 + __rows, _ := d.db.Query(` +SELECT pg_attribute.attname +FROM pg_constraint +INNER JOIN pg_class ON pg_constraint.conrelid = pg_class.oid +INNER JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid +AND pg_attribute.attnum = pg_constraint.conkey [ 1 ] +INNER JOIN pg_type ON pg_type.oid = pg_attribute.atttypid +WHERE pg_class.relname = $1 AND pg_constraint.contype = 'p' + `, tableName) + defer __rows.Close() + for __rows.Next() { + _ = __rows.Scan(&pkey) + } + for _rows.Next() { + var ( + column_name string + column_default *string + is_nullable string + udt_name string + comment *string + ) + err = _rows.Scan(&column_name, &column_default, &is_nullable, &udt_name, &comment) + if err != nil { + panic(err) + } + var columnComment string + if comment != nil { + columnComment = *comment + } + var ColumnDefault string + if column_default != nil { + ColumnDefault = *column_default + } + + ormColumns[tableName] = append(ormColumns[tableName], tableColumn{ + ColumnName: parser.StringToHump(column_name), + ColumnDefault: ColumnDefault, + PgType: udt_name, + GoType: PgTypeToGoType(udt_name, column_name), + IsNullable: is_nullable == "YES", + IsPKey: false, + Comment: columnComment, + }) + } + } + return ormColumns +} + +type tableColumn struct { + // 驼峰命名的字段 + ColumnName string + ColumnDefault string + PgType string + GoType string + IsNullable bool + IsPKey bool + Comment string +} + +func PgTypeToGoType(pgType string, columnName string) string { + switch pgType { + case "int2", "int4": + return "int32" + case "int8": + return "int64" + case "date": + return "datatypes.Date" + case "json", "jsonb": + return "database.JSON" + case "time", "timetz": + return "database.Time" + case "numeric": + return "float64" + default: + if strings.Contains(pgType, "timestamp") { + if columnName == "deleted_at" { + return "gorm.DeletedAt" + } else { + return "database.Time" + } + } + return "string" + } +} + +func NewDb(conf map[interface{}]interface{}) *DB { + config := services.NewConfig(conf) + connStr := fmt.Sprintf( + "postgres://%s:%s@%s:%d/%s?sslmode=disable", + config.GetString("username", "root"), + config.GetString("password", "123456"), + config.GetString("host", "localhost:"), + config.GetInt("port", 5432), + config.GetString("database", "demo"), + ) + db, err := sql.Open("postgres", connStr) + if err != nil { + panic(err) + } + // See "Important settings" section. + db.SetConnMaxLifetime(time.Minute * 3) + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(10) + + return &DB{ + db: db, + } +} diff --git a/console/commands/pgorm/pgsql.go.text b/console/commands/pgorm/pgsql.go.text new file mode 100644 index 0000000..89add6f --- /dev/null +++ b/console/commands/pgorm/pgsql.go.text @@ -0,0 +1,223 @@ + +func (receiver *{orm_table_name}) TableName() string { + return "{table_name}" +} + +type Orm{orm_table_name} struct { + db *gorm.DB +} + +func NewOrm{orm_table_name}() *Orm{orm_table_name} { + orm := &Orm{orm_table_name}{} + orm.db = providers.NewMysqlProvider().GetBean("{db}").(*gorm.DB) + return orm +} + +func (orm *Orm{orm_table_name}) GetDB() *gorm.DB { + return orm.db +} + +// Create insert the value into database +func (orm *Orm{orm_table_name}) Create(value interface{}) *gorm.DB { + return orm.db.Create(value) +} + +// CreateInBatches insert the value in batches into database +func (orm *Orm{orm_table_name}) CreateInBatches(value interface{}, batchSize int) *gorm.DB { + return orm.db.CreateInBatches(value, batchSize) +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (orm *Orm{orm_table_name}) Save(value interface{}) *gorm.DB { + return orm.db.Save(value) +} + +func (orm *Orm{orm_table_name}) Row() *sql.Row { + return orm.db.Row() +} + +func (orm *Orm{orm_table_name}) Rows() (*sql.Rows, error) { + return orm.db.Rows() +} + +// Scan scan value to a struct +func (orm *Orm{orm_table_name}) Scan(dest interface{}) *gorm.DB { + return orm.db.Scan(dest) +} + +func (orm *Orm{orm_table_name}) ScanRows(rows *sql.Rows, dest interface{}) error { + return orm.db.ScanRows(rows, dest) +} + +// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +func (orm *Orm{orm_table_name}) Connection(fc func(tx *gorm.DB) error) (err error) { + return orm.db.Connection(fc) +} + +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +func (orm *Orm{orm_table_name}) Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error) { + return orm.db.Transaction(fc, opts...) +} + +// Begin begins a transaction +func (orm *Orm{orm_table_name}) Begin(opts ...*sql.TxOptions) *gorm.DB { + return orm.db.Begin(opts...) +} + +// Commit commit a transaction +func (orm *Orm{orm_table_name}) Commit() *gorm.DB { + return orm.db.Commit() +} + +// Rollback rollback a transaction +func (orm *Orm{orm_table_name}) Rollback() *gorm.DB { + return orm.db.Rollback() +} + +func (orm *Orm{orm_table_name}) SavePoint(name string) *gorm.DB { + return orm.db.SavePoint(name) +} + +func (orm *Orm{orm_table_name}) RollbackTo(name string) *gorm.DB { + return orm.db.RollbackTo(name) +} + +// Exec execute raw sql +func (orm *Orm{orm_table_name}) Exec(sql string, values ...interface{}) *gorm.DB { + return orm.db.Exec(sql, values...) +} + +// ------------ 以下是单表独有的函数, 便捷字段条件, Laravel风格操作 --------- + +func (orm *Orm{orm_table_name}) Insert(row *{orm_table_name}) *gorm.DB { + return orm.db.Create(row) +} + +func (orm *Orm{orm_table_name}) Inserts(rows []*{orm_table_name}) *gorm.DB { + return orm.db.Create(rows) +} + +func (orm *Orm{orm_table_name}) Order(value interface{}) *Orm{orm_table_name} { + orm.db.Order(value) + return orm +} + +func (orm *Orm{orm_table_name}) Limit(limit int) *Orm{orm_table_name} { + orm.db.Limit(limit) + return orm +} + +func (orm *Orm{orm_table_name}) Offset(offset int) *Orm{orm_table_name} { + orm.db.Offset(offset) + return orm +} +// 直接查询列表, 如果需要条数, 使用Find() +func (orm *Orm{orm_table_name}) Get() {orm_table_name}List { + got, _ := orm.Find() + return got +} + +// Pluck used to query single column from a model as a map +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) +func (orm *Orm{orm_table_name}) Pluck(column string, dest interface{}) *gorm.DB { + return orm.db.Model(&{orm_table_name}{}).Pluck(column, dest) +} + +// Delete 有条件删除 +func (orm *Orm{orm_table_name}) Delete(conds ...interface{}) *gorm.DB { + return orm.db.Delete(&{orm_table_name}{}, conds...) +} + +// DeleteAll 删除所有 +func (orm *Orm{orm_table_name}) DeleteAll() *gorm.DB { + return orm.db.Exec("DELETE FROM {table_name}") +} + +func (orm *Orm{orm_table_name}) Count() int64 { + var count int64 + orm.db.Model(&{orm_table_name}{}).Count(&count) + return count +} + +// First 检索单个对象 +func (orm *Orm{orm_table_name}) First(conds ...interface{}) (*{orm_table_name}, bool) { + dest := &{orm_table_name}{} + db := orm.db.Limit(1).Find(dest, conds...) + return dest, db.RowsAffected == 1 +} + +// Take return a record that match given conditions, the order will depend on the database implementation +func (orm *Orm{orm_table_name}) Take(conds ...interface{}) (*{orm_table_name}, int64) { + dest := &{orm_table_name}{} + db := orm.db.Take(dest, conds...) + return dest, db.RowsAffected +} + +// Last find last record that match given conditions, order by primary key +func (orm *Orm{orm_table_name}) Last(conds ...interface{}) (*{orm_table_name}, int64) { + dest := &{orm_table_name}{} + db := orm.db.Last(dest, conds...) + return dest, db.RowsAffected +} + +func (orm *Orm{orm_table_name}) Find(conds ...interface{}) ({orm_table_name}List, int64) { + list := make([]*{orm_table_name}, 0) + tx := orm.db.Model(&{orm_table_name}{}).Find(&list, conds...) + if tx.Error != nil { + logrus.Error(tx.Error) + } + return list, tx.RowsAffected +} + +// FindInBatches find records in batches +func (orm *Orm{orm_table_name}) FindInBatches(dest interface{}, batchSize int, fc func(tx *gorm.DB, batch int) error) *gorm.DB { + return orm.db.FindInBatches(dest, batchSize, fc) +} + +// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) +func (orm *Orm{orm_table_name}) FirstOrInit(dest *{orm_table_name}, conds ...interface{}) (*{orm_table_name}, *gorm.DB) { + return dest, orm.db.FirstOrInit(dest, conds...) +} + +// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) +func (orm *Orm{orm_table_name}) FirstOrCreate(dest interface{}, conds ...interface{}) *gorm.DB { + return orm.db.FirstOrCreate(dest, conds...) +} + +// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +func (orm *Orm{orm_table_name}) Update(column string, value interface{}) *gorm.DB { + return orm.db.Model(&{orm_table_name}{}).Update(column, value) +} + +// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +func (orm *Orm{orm_table_name}) Updates(values interface{}) *gorm.DB { + return orm.db.Model(&{orm_table_name}{}).Updates(values) +} + +func (orm *Orm{orm_table_name}) UpdateColumn(column string, value interface{}) *gorm.DB { + return orm.db.Model(&{orm_table_name}{}).UpdateColumn(column, value) +} + +func (orm *Orm{orm_table_name}) UpdateColumns(values interface{}) *gorm.DB { + return orm.db.Model(&{orm_table_name}{}).UpdateColumns(values) +} + +func (orm *Orm{orm_table_name}) Where(query interface{}, args ...interface{}) *Orm{orm_table_name} { + orm.db.Where(query, args...) + return orm +} + +func (orm *Orm{orm_table_name}) And(fuc func(orm *Orm{orm_table_name})) *Orm{orm_table_name} { + ormAnd := NewOrm{orm_table_name}() + fuc(ormAnd) + orm.db.Where(ormAnd.db) + return orm +} + +func (orm *Orm{orm_table_name}) Or(fuc func(orm *Orm{orm_table_name})) *Orm{orm_table_name} { + ormOr := NewOrm{orm_table_name}() + fuc(ormOr) + orm.db.Or(ormOr.db) + return orm +} diff --git a/go.mod b/go.mod index 3d5ef05..ae03d94 100644 --- a/go.mod +++ b/go.mod @@ -7,5 +7,6 @@ require ( github.com/go-home-admin/home v0.0.3 github.com/go-sql-driver/mysql v1.6.0 github.com/joho/godotenv v1.4.0 + github.com/lib/pq v1.10.6 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index 9d1d77b..6d51729 100644 --- a/go.sum +++ b/go.sum @@ -78,6 +78,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= +github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=