diff --git a/console/commands/orm/mysql.go b/console/commands/orm/mysql.go index ac8abaf..7054a40 100644 --- a/console/commands/orm/mysql.go +++ b/console/commands/orm/mysql.go @@ -10,6 +10,7 @@ import ( _ "github.com/go-sql-driver/mysql" "log" "os" + "strconv" "strings" "time" ) @@ -361,6 +362,8 @@ func (d *DB) GetDB() *sql.DB { return d.db } +// 获取所有表信息 +// 过滤分表信息, table_{1-9} 只返回table func (d *DB) tableColumns() map[string][]tableColumn { var sqlStr = `SELECT TABLE_CATALOG, @@ -446,7 +449,27 @@ ORDER BY tableColumns[col.TABLE_NAME] = append(tableColumns[col.TABLE_NAME], col) } - return tableColumns + return Filter(tableColumns) +} + +// Filter 过滤分表格式 +// table_{0-9} 只返回table +func Filter(tableColumns map[string][]tableColumn) map[string][]tableColumn { + got := make(map[string][]tableColumn) + for tableName, columns := range tableColumns { + arr := strings.Split(tableName, "_") + arrLen := len(arr) + if arrLen > 1 { + str := arr[arrLen-1] + _, err := strconv.Atoi(str) + if err == nil { + tableName = strings.ReplaceAll(tableName, "_"+str, "") + } + } + + got[tableName] = columns + } + return got } func (d *DB) tableIndex() map[string]map[string][]tableColumnIndex {