From 40c40333611f2501a485aa2ac44b019a3d45c271 Mon Sep 17 00:00:00 2001 From: zodial Date: Wed, 22 Oct 2025 16:23:17 +0800 Subject: [PATCH] =?UTF-8?q?feature:=20PGORM=E6=94=AF=E6=8C=81=E6=8C=87?= =?UTF-8?q?=E5=AE=9ASchema=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- console/commands/orm.go | 8 +++++++- console/commands/pgorm/pgsql.go | 32 ++++++++++++++++++++++---------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/console/commands/orm.go b/console/commands/orm.go index df8961e..e53db98 100644 --- a/console/commands/orm.go +++ b/console/commands/orm.go @@ -33,6 +33,11 @@ func (OrmCommand) Configure() command.Configure { Description: "输出文件", Default: "@root/app/entity", }, + { + Name: "schema", + Description: "specified schema, only works in PostgreSql", + Default: "public", + }, }, }, } @@ -44,6 +49,7 @@ func (OrmCommand) Execute(input command.Input) { file = strings.Replace(file, "@root", root, 1) outBase := input.GetOption("out") outBase = strings.Replace(outBase, "@root", root, 1) + schema := input.GetOption("schema") err := godotenv.Load(root + "/.env") if err != nil { @@ -66,7 +72,7 @@ func (OrmCommand) Execute(input command.Input) { case "mysql": orm.GenMysql(s.(string), conf, out) case "postgresql": - pgorm.GenSql(s.(string), conf, out) + pgorm.GenSql(s.(string), conf, out, schema) } cmd := exec.Command("go", []string{"fmt", out}...) diff --git a/console/commands/pgorm/pgsql.go b/console/commands/pgorm/pgsql.go index 762d459..cd26208 100644 --- a/console/commands/pgorm/pgsql.go +++ b/console/commands/pgorm/pgsql.go @@ -26,17 +26,25 @@ func IsExist(f string) bool { type Conf map[interface{}]interface{} -func GenSql(name string, conf Conf, out string) { +func GenSql(name string, conf Conf, out string, schema string) { + relationshipName := name + packageName := name + if schema != "public" { + relationshipName = fmt.Sprintf("%s.%s", relationshipName, schema) + packageName = schema + out += "/" + schema + } + if !IsExist(out) { os.MkdirAll(out, 0766) } db := NewDb(conf) - tableInfos := db.tableColumns() + tableInfos := db.tableColumns(schema) tableColumns := tableInfos.Columns root, _ := os.Getwd() - file, err := os.ReadFile(root + "/config/database/" + name + ".json") + file, err := os.ReadFile(root + "/config/database/" + relationshipName + ".json") var relationship map[string][]*orm.Relationship if err == nil { err = json.Unmarshal(file, &relationship) @@ -50,13 +58,17 @@ func GenSql(name string, conf Conf, out string) { for table, columns := range tableColumns { tableConfig := tableInfos.Infos[table] mysqlTableName := parser.StringToSnake(table) + tableName := table + if schema != "public" { + tableName = fmt.Sprintf("%s.%s", schema, table) + } file := out + "/" + mysqlTableName if _, err := os.Stat(file + "_lock.go"); !os.IsNotExist(err) { continue } - str := "package " + name + str := "package " + packageName str += "\nimport (" + imports[table] + "\n)" str += "\n" + genOrmStruct(table, columns, conf, relationship[table]) @@ -68,7 +80,7 @@ func GenSql(name string, conf Conf, out string) { } for old, newStr := range map[string]string{ "{orm_table_name}": parser.StringToHump(table), - "{table_name}": table, + "{table_name}": tableName, "{connect_name}": name, } { baseFunStr = strings.ReplaceAll(baseFunStr, old, newStr) @@ -391,8 +403,8 @@ func (d *DB) GetDB() *sql.DB { // 获取所有表信息 // 过滤分表信息, table_{1-9} 只返回table -func (d *DB) tableColumns() TableInfo { - var sqlStr = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'" +func (d *DB) tableColumns(schema string) TableInfo { + var sqlStr = fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = '%s'", schema) rows, err := d.db.Query(sqlStr) if err != nil { @@ -407,15 +419,15 @@ func (d *DB) tableColumns() TableInfo { _ = rows.Scan( &tableName, ) - _rows, _ := d.db.Query(` + _rows, _ := d.db.Query(fmt.Sprintf(` SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment, ixc.relname FROM information_schema.columns as i LEFT JOIN pg_class as c on c.relname = i.table_name LEFT JOIN pg_attribute as a on a.attrelid = c.oid and a.attname = i.column_name LEFT JOIN pg_index ix ON c.oid = ix.indrelid AND a.attnum = ANY(ix.indkey) LEFT JOIN pg_class ixc ON ixc.oid = ix.indexrelid -WHERE table_schema = 'public' and i.table_name = $1; - `, tableName) +WHERE table_schema = '%s' and i.table_name = '%s'; + `, schema, tableName)) defer _rows.Close() //获取主键 __rows, _ := d.db.Query(`