mirror of
https://github.com/tangpanqing/aorm.git
synced 2025-09-27 04:16:26 +08:00
103 lines
2.6 KiB
Go
103 lines
2.6 KiB
Go
package migrator
|
|
|
|
import (
|
|
"github.com/tangpanqing/aorm/base"
|
|
"github.com/tangpanqing/aorm/builder"
|
|
"github.com/tangpanqing/aorm/driver"
|
|
"github.com/tangpanqing/aorm/migrate_mssql"
|
|
"github.com/tangpanqing/aorm/migrate_mysql"
|
|
"github.com/tangpanqing/aorm/migrate_postgres"
|
|
"github.com/tangpanqing/aorm/migrate_sqlite3"
|
|
"github.com/tangpanqing/aorm/utils"
|
|
"reflect"
|
|
"strings"
|
|
)
|
|
|
|
type Migrator struct {
|
|
//数据库操作连接
|
|
Link base.Link
|
|
}
|
|
|
|
//ShowCreateTable 获取创建表的ddl
|
|
func (mi *Migrator) ShowCreateTable(tableName string) string {
|
|
if mi.Link.DriverName() == driver.Mysql {
|
|
me := migrate_mysql.MigrateExecutor{
|
|
Builder: &builder.Builder{
|
|
Link: mi.Link,
|
|
},
|
|
}
|
|
return me.ShowCreateTable(tableName)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// AutoMigrate 迁移数据库结构,需要输入数据库名,表名自动获取
|
|
func (mi *Migrator) AutoMigrate(destList ...interface{}) {
|
|
for i := 0; i < len(destList); i++ {
|
|
dest := destList[i]
|
|
typeOf := reflect.TypeOf(dest)
|
|
valueOf := reflect.ValueOf(dest)
|
|
tableName := getTableNameByReflect(typeOf, valueOf)
|
|
mi.migrateCommon(tableName, typeOf, valueOf)
|
|
}
|
|
}
|
|
|
|
// Migrate 自动迁移数据库结构,需要输入数据库名,表名
|
|
func (mi *Migrator) Migrate(tableName string, dest interface{}) {
|
|
typeOf := reflect.TypeOf(dest)
|
|
valueOf := reflect.ValueOf(dest)
|
|
mi.migrateCommon(tableName, typeOf, valueOf)
|
|
}
|
|
|
|
func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) {
|
|
if mi.Link.DriverName() == driver.Mssql {
|
|
me := migrate_mssql.MigrateExecutor{
|
|
Builder: &builder.Builder{
|
|
Link: mi.Link,
|
|
},
|
|
}
|
|
me.MigrateCommon(tableName, typeOf)
|
|
}
|
|
|
|
if mi.Link.DriverName() == driver.Mysql {
|
|
me := migrate_mysql.MigrateExecutor{
|
|
Builder: &builder.Builder{
|
|
Link: mi.Link,
|
|
},
|
|
}
|
|
me.MigrateCommon(tableName, typeOf, valueOf)
|
|
}
|
|
|
|
if mi.Link.DriverName() == driver.Sqlite3 {
|
|
me := migrate_sqlite3.MigrateExecutor{
|
|
Builder: &builder.Builder{
|
|
Link: mi.Link,
|
|
},
|
|
}
|
|
me.MigrateCommon(tableName, typeOf)
|
|
}
|
|
|
|
if mi.Link.DriverName() == driver.Postgres {
|
|
me := migrate_postgres.MigrateExecutor{
|
|
Builder: &builder.Builder{
|
|
Link: mi.Link,
|
|
},
|
|
}
|
|
me.MigrateCommon(tableName, typeOf, valueOf)
|
|
}
|
|
}
|
|
|
|
//反射表名,优先从方法获取,没有方法则从名字获取
|
|
func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
|
|
method, isSet := typeOf.MethodByName("TableName")
|
|
if isSet {
|
|
var paramList []reflect.Value
|
|
paramList = append(paramList, valueOf)
|
|
res := method.Func.Call(paramList)
|
|
return res[0].String()
|
|
} else {
|
|
arr := strings.Split(typeOf.String(), ".")
|
|
return utils.UnderLine(arr[len(arr)-1])
|
|
}
|
|
}
|