feature: PGORM支持指定Schema生成

This commit is contained in:
zodial
2025-10-22 16:23:17 +08:00
parent a41bc8b016
commit 40c4033361
2 changed files with 29 additions and 11 deletions

View File

@@ -33,6 +33,11 @@ func (OrmCommand) Configure() command.Configure {
Description: "输出文件",
Default: "@root/app/entity",
},
{
Name: "schema",
Description: "specified schema, only works in PostgreSql",
Default: "public",
},
},
},
}
@@ -44,6 +49,7 @@ func (OrmCommand) Execute(input command.Input) {
file = strings.Replace(file, "@root", root, 1)
outBase := input.GetOption("out")
outBase = strings.Replace(outBase, "@root", root, 1)
schema := input.GetOption("schema")
err := godotenv.Load(root + "/.env")
if err != nil {
@@ -66,7 +72,7 @@ func (OrmCommand) Execute(input command.Input) {
case "mysql":
orm.GenMysql(s.(string), conf, out)
case "postgresql":
pgorm.GenSql(s.(string), conf, out)
pgorm.GenSql(s.(string), conf, out, schema)
}
cmd := exec.Command("go", []string{"fmt", out}...)

View File

@@ -26,17 +26,25 @@ func IsExist(f string) bool {
type Conf map[interface{}]interface{}
func GenSql(name string, conf Conf, out string) {
func GenSql(name string, conf Conf, out string, schema string) {
relationshipName := name
packageName := name
if schema != "public" {
relationshipName = fmt.Sprintf("%s.%s", relationshipName, schema)
packageName = schema
out += "/" + schema
}
if !IsExist(out) {
os.MkdirAll(out, 0766)
}
db := NewDb(conf)
tableInfos := db.tableColumns()
tableInfos := db.tableColumns(schema)
tableColumns := tableInfos.Columns
root, _ := os.Getwd()
file, err := os.ReadFile(root + "/config/database/" + name + ".json")
file, err := os.ReadFile(root + "/config/database/" + relationshipName + ".json")
var relationship map[string][]*orm.Relationship
if err == nil {
err = json.Unmarshal(file, &relationship)
@@ -50,13 +58,17 @@ func GenSql(name string, conf Conf, out string) {
for table, columns := range tableColumns {
tableConfig := tableInfos.Infos[table]
mysqlTableName := parser.StringToSnake(table)
tableName := table
if schema != "public" {
tableName = fmt.Sprintf("%s.%s", schema, table)
}
file := out + "/" + mysqlTableName
if _, err := os.Stat(file + "_lock.go"); !os.IsNotExist(err) {
continue
}
str := "package " + name
str := "package " + packageName
str += "\nimport (" + imports[table] + "\n)"
str += "\n" + genOrmStruct(table, columns, conf, relationship[table])
@@ -68,7 +80,7 @@ func GenSql(name string, conf Conf, out string) {
}
for old, newStr := range map[string]string{
"{orm_table_name}": parser.StringToHump(table),
"{table_name}": table,
"{table_name}": tableName,
"{connect_name}": name,
} {
baseFunStr = strings.ReplaceAll(baseFunStr, old, newStr)
@@ -391,8 +403,8 @@ func (d *DB) GetDB() *sql.DB {
// 获取所有表信息
// 过滤分表信息, table_{1-9} 只返回table
func (d *DB) tableColumns() TableInfo {
var sqlStr = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
func (d *DB) tableColumns(schema string) TableInfo {
var sqlStr = fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = '%s'", schema)
rows, err := d.db.Query(sqlStr)
if err != nil {
@@ -407,15 +419,15 @@ func (d *DB) tableColumns() TableInfo {
_ = rows.Scan(
&tableName,
)
_rows, _ := d.db.Query(`
_rows, _ := d.db.Query(fmt.Sprintf(`
SELECT i.column_name, i.column_default, i.is_nullable, i.udt_name, col_description(a.attrelid,a.attnum) as comment, ixc.relname
FROM information_schema.columns as i
LEFT JOIN pg_class as c on c.relname = i.table_name
LEFT JOIN pg_attribute as a on a.attrelid = c.oid and a.attname = i.column_name
LEFT JOIN pg_index ix ON c.oid = ix.indrelid AND a.attnum = ANY(ix.indkey)
LEFT JOIN pg_class ixc ON ixc.oid = ix.indexrelid
WHERE table_schema = 'public' and i.table_name = $1;
`, tableName)
WHERE table_schema = '%s' and i.table_name = '%s';
`, schema, tableName))
defer _rows.Close()
//获取主键
__rows, _ := d.db.Query(`