This commit is contained in:
tangpanqing
2023-01-10 09:52:08 +08:00
parent bc43028389
commit 10a7bc9230
8 changed files with 45 additions and 60 deletions

View File

@@ -130,13 +130,12 @@ func getFieldName(field interface{}) string {
}
//反射表名,优先从方法获取,没有方法则从名字获取
func getTableNameByReflect(typeOf reflect.Type) string {
func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
method, isSet := typeOf.MethodByName("TableName")
fmt.Println("=isSet=")
fmt.Println(typeOf)
fmt.Println(isSet)
if isSet {
res := method.Func.Call(nil)
var paramList []reflect.Value
paramList = append(paramList, valueOf)
res := method.Func.Call(paramList)
return res[0].String()
} else {
arr := strings.Split(typeOf.String(), ".")

View File

@@ -78,7 +78,7 @@ func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value)
return getTableNameByTable(b.table)
}
return getTableNameByReflect(typeOf)
return getTableNameByReflect(typeOf, valueOf)
}
// Insert 增加记录
@@ -175,24 +175,11 @@ func convertToPostgresSql(sqlStr string) string {
// InsertBatch 批量增加记录
func (b *Builder) InsertBatch(values interface{}) (int64, error) {
TypeOf := reflect.TypeOf(values)
ValueOf := reflect.ValueOf(values)
fmt.Println(TypeOf)
fmt.Println(ValueOf)
fmt.Println(TypeOf.Elem())
fmt.Println(ValueOf.Elem())
fmt.Println(ValueOf.NumField())
fmt.Println(ValueOf.Elem().NumField())
return 0, nil
var keys []string
var paramList []any
var place []string
valueOf := reflect.ValueOf(values).Elem()
fmt.Println(valueOf.NumField())
if valueOf.Len() == 0 {
return 0, errors.New("the data list for insert batch not found")
@@ -202,15 +189,15 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) {
for j := 0; j < valueOf.Len(); j++ {
var placeItem []string
for i := 0; i < valueOf.Index(j).NumField(); i++ {
isNotNull := valueOf.Index(j).Field(i).Field(0).Field(1).Bool()
for i := 0; i < valueOf.Index(j).Elem().NumField(); i++ {
isNotNull := valueOf.Index(j).Elem().Field(i).Field(0).Field(1).Bool()
if isNotNull {
if j == 0 {
key := helper.UnderLine(typeOf.Field(i).Name)
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
keys = append(keys, key)
}
val := valueOf.Index(j).Field(i).Field(0).Field(0).Interface()
val := valueOf.Index(j).Elem().Field(i).Field(0).Field(0).Interface()
paramList = append(paramList, val)
placeItem = append(placeItem, "?")
}
@@ -219,7 +206,6 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) {
place = append(place, "("+strings.Join(placeItem, ",")+")")
}
fmt.Println("--InsertBatch--")
sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf.Index(0)) + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",")
if b.driverName == model.Postgres {

View File

@@ -84,7 +84,7 @@ func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramLis
//如果没有设置表名
if b.tableName == "" {
b.tableName = getTableNameByReflect(typeOf)
b.tableName = getTableNameByReflect(typeOf, valueOf)
}
var keys []string

View File

@@ -12,7 +12,7 @@ func (b *Builder) Having(dest interface{}) *Builder {
//如果没有设置表名
if b.tableName == "" {
b.tableName = getTableNameByReflect(typeOf)
b.tableName = getTableNameByReflect(typeOf, valueOf)
}
for i := 0; i < typeOf.Elem().NumField(); i++ {

View File

@@ -12,7 +12,7 @@ func (b *Builder) Where(dest interface{}) *Builder {
//如果没有设置表名
if b.tableName == "" {
b.tableName = getTableNameByReflect(typeOf)
b.tableName = getTableNameByReflect(typeOf, valueOf)
}
for i := 0; i < typeOf.Elem().NumField(); i++ {

View File

@@ -53,8 +53,8 @@ func (mm *MigrateExecutor) ShowCreateTable(tableName string) string {
}
//MigrateCommon 迁移的主要过程
func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type) error {
tableFromCode := mm.getTableFromCode(tableName, typeOf)
func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) error {
tableFromCode := mm.getTableFromCode(tableName, typeOf, valueOf)
columnsFromCode := mm.getColumnsFromCode(typeOf)
indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode)
@@ -77,7 +77,7 @@ func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type)
return nil
}
func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type) Table {
func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type, valueOf reflect.Value) Table {
table := Table{
TableName: null.StringFrom(tableName),
Engine: null.StringFrom("MyISAM"),
@@ -86,7 +86,9 @@ func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Typ
method, isSet := typeOf.MethodByName("TableOpinion")
if isSet {
valueList := method.Func.Call(nil)
var paramList []reflect.Value
paramList = append(paramList, valueOf)
valueList := method.Func.Call(paramList)
i := valueList[0].Interface()
m := i.(map[string]string)

View File

@@ -58,18 +58,20 @@ func (mi *Migrator) AutoMigrate(destList ...interface{}) {
for i := 0; i < len(destList); i++ {
dest := destList[i]
typeOf := reflect.TypeOf(dest)
tableName := getTableNameByReflect(typeOf)
mi.migrateCommon(tableName, typeOf)
valueOf := reflect.ValueOf(dest)
tableName := getTableNameByReflect(typeOf, valueOf)
mi.migrateCommon(tableName, typeOf, valueOf)
}
}
// Migrate 自动迁移数据库结构,需要输入数据库名,表名
func (mi *Migrator) Migrate(tableName string, dest interface{}) {
typeOf := reflect.TypeOf(dest)
mi.migrateCommon(tableName, typeOf)
valueOf := reflect.ValueOf(dest)
mi.migrateCommon(tableName, typeOf, valueOf)
}
func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) {
func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) {
if mi.driverName == model.Mssql {
me := migrate_mssql.MigrateExecutor{
DriverName: mi.driverName,
@@ -89,7 +91,7 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) {
LinkCommon: mi.LinkCommon,
},
}
me.MigrateCommon(tableName, typeOf)
me.MigrateCommon(tableName, typeOf, valueOf)
}
if mi.driverName == model.Sqlite3 {
@@ -120,10 +122,12 @@ func (mi *Migrator) GetOpinionList() []model.OpinionItem {
}
//反射表名,优先从方法获取,没有方法则从名字获取
func getTableNameByReflect(typeOf reflect.Type) string {
func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
method, isSet := typeOf.MethodByName("TableName")
if isSet {
res := method.Func.Call(nil)
var paramList []reflect.Value
paramList = append(paramList, valueOf)
res := method.Func.Call(paramList)
return res[0].String()
} else {
arr := strings.Split(typeOf.String(), ".")

View File

@@ -60,10 +60,6 @@ type Person struct {
Test null.Float `aorm:"type:double;comment:测试" json:"test"`
}
func (p *Person) TableName() string {
return "erp_person"
}
func (p *Person) TableOpinion() map[string]string {
return map[string]string{
"ENGINE": "InnoDB",
@@ -100,11 +96,12 @@ func TestAll(t *testing.T) {
aorm.Store(&articleVO)
aorm.Store(&personAge, &personWithArticleCount)
dbList := make([]aorm.DbContent, 0)
dbList = append(dbList, testSqlite3Connect())
dbList = append(dbList, testMysqlConnect())
dbList = append(dbList, testPostgresConnect())
dbList = append(dbList, testMssqlConnect())
var dbList = []aorm.DbContent{
testMysqlConnect(),
testSqlite3Connect(),
testPostgresConnect(),
testMssqlConnect(),
}
for i := 0; i < len(dbList); i++ {
dbItem := dbList[i]
@@ -115,7 +112,6 @@ func TestAll(t *testing.T) {
id := testInsert(dbItem.DriverName, dbItem.DbLink)
testInsertBatch(dbItem.DriverName, dbItem.DbLink)
break
testGetOne(dbItem.DriverName, dbItem.DbLink, id)
testGetMany(dbItem.DriverName, dbItem.DbLink)
testUpdate(dbItem.DriverName, dbItem.DbLink, id)
@@ -227,10 +223,8 @@ func testMssqlConnect() aorm.DbContent {
}
func testMigrate(driver string, db *sql.DB) {
//AutoMigrate
aorm.Migrator(db).Driver(driver).AutoMigrate(&person, &article, &student)
//Migrate
aorm.Migrator(db).Driver(driver).Migrate("person_1", &person)
}
@@ -249,15 +243,15 @@ func testInsert(driver string, db *sql.DB) int64 {
Test: null.FloatFrom(2),
}
id, errInsert := aorm.Db(db).Debug(true).Driver(driver).Insert(&obj)
id, errInsert := aorm.Db(db).Debug(false).Driver(driver).Insert(&obj)
if errInsert != nil {
panic(driver + " testInsert " + "found err: " + errInsert.Error())
}
//aorm.Db(db).Debug(false).Driver(driver).Insert(&Article{
// Type: null.IntFrom(0),
// PersonId: null.IntFrom(id),
// ArticleBody: null.StringFrom("文章内容"),
//})
aorm.Db(db).Debug(false).Driver(driver).Insert(&Article{
Type: null.IntFrom(0),
PersonId: null.IntFrom(id),
ArticleBody: null.StringFrom("文章内容"),
})
var personItem Person
err := aorm.Db(db).Table(&person).Debug(false).Driver(driver).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem)
@@ -290,9 +284,9 @@ func testInsert(driver string, db *sql.DB) int64 {
}
//测试非id主键
//aorm.Db(db).Debug(false).Driver(driver).Insert(&Student{
// Name: null.StringFrom("new student"),
//})
aorm.Db(db).Debug(false).Driver(driver).Insert(&Student{
Name: null.StringFrom("new student"),
})
return id
}
@@ -319,7 +313,7 @@ func testInsertBatch(driver string, db *sql.DB) int64 {
Test: null.FloatFrom(200.15987654321987654321),
})
count, err := aorm.Db(db).Debug(true).Driver(driver).InsertBatch(&batch)
count, err := aorm.Db(db).Debug(false).Driver(driver).InsertBatch(&batch)
if err != nil {
panic(driver + " testInsertBatch " + "found err:" + err.Error())
}