diff --git a/builder/builder.go b/builder/builder.go index 03933c0..42c0ac7 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -130,13 +130,12 @@ func getFieldName(field interface{}) string { } //反射表名,优先从方法获取,没有方法则从名字获取 -func getTableNameByReflect(typeOf reflect.Type) string { +func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string { method, isSet := typeOf.MethodByName("TableName") - fmt.Println("=isSet=") - fmt.Println(typeOf) - fmt.Println(isSet) if isSet { - res := method.Func.Call(nil) + var paramList []reflect.Value + paramList = append(paramList, valueOf) + res := method.Func.Call(paramList) return res[0].String() } else { arr := strings.Split(typeOf.String(), ".") diff --git a/builder/crud.go b/builder/crud.go index 268b84e..4e11a35 100644 --- a/builder/crud.go +++ b/builder/crud.go @@ -78,7 +78,7 @@ func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value) return getTableNameByTable(b.table) } - return getTableNameByReflect(typeOf) + return getTableNameByReflect(typeOf, valueOf) } // Insert 增加记录 @@ -175,24 +175,11 @@ func convertToPostgresSql(sqlStr string) string { // InsertBatch 批量增加记录 func (b *Builder) InsertBatch(values interface{}) (int64, error) { - - TypeOf := reflect.TypeOf(values) - ValueOf := reflect.ValueOf(values) - fmt.Println(TypeOf) - fmt.Println(ValueOf) - - fmt.Println(TypeOf.Elem()) - fmt.Println(ValueOf.Elem()) - fmt.Println(ValueOf.NumField()) - fmt.Println(ValueOf.Elem().NumField()) - return 0, nil - var keys []string var paramList []any var place []string valueOf := reflect.ValueOf(values).Elem() - fmt.Println(valueOf.NumField()) if valueOf.Len() == 0 { return 0, errors.New("the data list for insert batch not found") @@ -202,15 +189,15 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) { for j := 0; j < valueOf.Len(); j++ { var placeItem []string - for i := 0; i < valueOf.Index(j).NumField(); i++ { - isNotNull := valueOf.Index(j).Field(i).Field(0).Field(1).Bool() + for i := 0; i < valueOf.Index(j).Elem().NumField(); i++ { + isNotNull := valueOf.Index(j).Elem().Field(i).Field(0).Field(1).Bool() if isNotNull { if j == 0 { - key := helper.UnderLine(typeOf.Field(i).Name) + key := helper.UnderLine(typeOf.Elem().Field(i).Name) keys = append(keys, key) } - val := valueOf.Index(j).Field(i).Field(0).Field(0).Interface() + val := valueOf.Index(j).Elem().Field(i).Field(0).Field(0).Interface() paramList = append(paramList, val) placeItem = append(placeItem, "?") } @@ -219,7 +206,6 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) { place = append(place, "("+strings.Join(placeItem, ",")+")") } - fmt.Println("--InsertBatch--") sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf.Index(0)) + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") if b.driverName == model.Postgres { diff --git a/builder/handle.go b/builder/handle.go index 87281a9..74278f7 100644 --- a/builder/handle.go +++ b/builder/handle.go @@ -84,7 +84,7 @@ func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramLis //如果没有设置表名 if b.tableName == "" { - b.tableName = getTableNameByReflect(typeOf) + b.tableName = getTableNameByReflect(typeOf, valueOf) } var keys []string diff --git a/builder/having.go b/builder/having.go index 5b8d2b0..2d39e27 100644 --- a/builder/having.go +++ b/builder/having.go @@ -12,7 +12,7 @@ func (b *Builder) Having(dest interface{}) *Builder { //如果没有设置表名 if b.tableName == "" { - b.tableName = getTableNameByReflect(typeOf) + b.tableName = getTableNameByReflect(typeOf, valueOf) } for i := 0; i < typeOf.Elem().NumField(); i++ { diff --git a/builder/where.go b/builder/where.go index d72bc85..d84920f 100644 --- a/builder/where.go +++ b/builder/where.go @@ -12,7 +12,7 @@ func (b *Builder) Where(dest interface{}) *Builder { //如果没有设置表名 if b.tableName == "" { - b.tableName = getTableNameByReflect(typeOf) + b.tableName = getTableNameByReflect(typeOf, valueOf) } for i := 0; i < typeOf.Elem().NumField(); i++ { diff --git a/migrate_mysql/migrate.go b/migrate_mysql/migrate.go index 11f4776..1d42e23 100644 --- a/migrate_mysql/migrate.go +++ b/migrate_mysql/migrate.go @@ -53,8 +53,8 @@ func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { } //MigrateCommon 迁移的主要过程 -func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type) error { - tableFromCode := mm.getTableFromCode(tableName, typeOf) +func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) error { + tableFromCode := mm.getTableFromCode(tableName, typeOf, valueOf) columnsFromCode := mm.getColumnsFromCode(typeOf) indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode) @@ -77,7 +77,7 @@ func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type) return nil } -func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type) Table { +func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type, valueOf reflect.Value) Table { table := Table{ TableName: null.StringFrom(tableName), Engine: null.StringFrom("MyISAM"), @@ -86,7 +86,9 @@ func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Typ method, isSet := typeOf.MethodByName("TableOpinion") if isSet { - valueList := method.Func.Call(nil) + var paramList []reflect.Value + paramList = append(paramList, valueOf) + valueList := method.Func.Call(paramList) i := valueList[0].Interface() m := i.(map[string]string) diff --git a/migrator/migrator.go b/migrator/migrator.go index 10bfd95..e2791a4 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -58,18 +58,20 @@ func (mi *Migrator) AutoMigrate(destList ...interface{}) { for i := 0; i < len(destList); i++ { dest := destList[i] typeOf := reflect.TypeOf(dest) - tableName := getTableNameByReflect(typeOf) - mi.migrateCommon(tableName, typeOf) + valueOf := reflect.ValueOf(dest) + tableName := getTableNameByReflect(typeOf, valueOf) + mi.migrateCommon(tableName, typeOf, valueOf) } } // Migrate 自动迁移数据库结构,需要输入数据库名,表名 func (mi *Migrator) Migrate(tableName string, dest interface{}) { typeOf := reflect.TypeOf(dest) - mi.migrateCommon(tableName, typeOf) + valueOf := reflect.ValueOf(dest) + mi.migrateCommon(tableName, typeOf, valueOf) } -func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { +func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) { if mi.driverName == model.Mssql { me := migrate_mssql.MigrateExecutor{ DriverName: mi.driverName, @@ -89,7 +91,7 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { LinkCommon: mi.LinkCommon, }, } - me.MigrateCommon(tableName, typeOf) + me.MigrateCommon(tableName, typeOf, valueOf) } if mi.driverName == model.Sqlite3 { @@ -120,10 +122,12 @@ func (mi *Migrator) GetOpinionList() []model.OpinionItem { } //反射表名,优先从方法获取,没有方法则从名字获取 -func getTableNameByReflect(typeOf reflect.Type) string { +func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string { method, isSet := typeOf.MethodByName("TableName") if isSet { - res := method.Func.Call(nil) + var paramList []reflect.Value + paramList = append(paramList, valueOf) + res := method.Func.Call(paramList) return res[0].String() } else { arr := strings.Split(typeOf.String(), ".") diff --git a/test/aorm_test.go b/test/aorm_test.go index a779ab9..4b521e8 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -60,10 +60,6 @@ type Person struct { Test null.Float `aorm:"type:double;comment:测试" json:"test"` } -func (p *Person) TableName() string { - return "erp_person" -} - func (p *Person) TableOpinion() map[string]string { return map[string]string{ "ENGINE": "InnoDB", @@ -100,11 +96,12 @@ func TestAll(t *testing.T) { aorm.Store(&articleVO) aorm.Store(&personAge, &personWithArticleCount) - dbList := make([]aorm.DbContent, 0) - dbList = append(dbList, testSqlite3Connect()) - dbList = append(dbList, testMysqlConnect()) - dbList = append(dbList, testPostgresConnect()) - dbList = append(dbList, testMssqlConnect()) + var dbList = []aorm.DbContent{ + testMysqlConnect(), + testSqlite3Connect(), + testPostgresConnect(), + testMssqlConnect(), + } for i := 0; i < len(dbList); i++ { dbItem := dbList[i] @@ -115,7 +112,6 @@ func TestAll(t *testing.T) { id := testInsert(dbItem.DriverName, dbItem.DbLink) testInsertBatch(dbItem.DriverName, dbItem.DbLink) - break testGetOne(dbItem.DriverName, dbItem.DbLink, id) testGetMany(dbItem.DriverName, dbItem.DbLink) testUpdate(dbItem.DriverName, dbItem.DbLink, id) @@ -227,10 +223,8 @@ func testMssqlConnect() aorm.DbContent { } func testMigrate(driver string, db *sql.DB) { - //AutoMigrate aorm.Migrator(db).Driver(driver).AutoMigrate(&person, &article, &student) - //Migrate aorm.Migrator(db).Driver(driver).Migrate("person_1", &person) } @@ -249,15 +243,15 @@ func testInsert(driver string, db *sql.DB) int64 { Test: null.FloatFrom(2), } - id, errInsert := aorm.Db(db).Debug(true).Driver(driver).Insert(&obj) + id, errInsert := aorm.Db(db).Debug(false).Driver(driver).Insert(&obj) if errInsert != nil { panic(driver + " testInsert " + "found err: " + errInsert.Error()) } - //aorm.Db(db).Debug(false).Driver(driver).Insert(&Article{ - // Type: null.IntFrom(0), - // PersonId: null.IntFrom(id), - // ArticleBody: null.StringFrom("文章内容"), - //}) + aorm.Db(db).Debug(false).Driver(driver).Insert(&Article{ + Type: null.IntFrom(0), + PersonId: null.IntFrom(id), + ArticleBody: null.StringFrom("文章内容"), + }) var personItem Person err := aorm.Db(db).Table(&person).Debug(false).Driver(driver).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) @@ -290,9 +284,9 @@ func testInsert(driver string, db *sql.DB) int64 { } //测试非id主键 - //aorm.Db(db).Debug(false).Driver(driver).Insert(&Student{ - // Name: null.StringFrom("new student"), - //}) + aorm.Db(db).Debug(false).Driver(driver).Insert(&Student{ + Name: null.StringFrom("new student"), + }) return id } @@ -319,7 +313,7 @@ func testInsertBatch(driver string, db *sql.DB) int64 { Test: null.FloatFrom(200.15987654321987654321), }) - count, err := aorm.Db(db).Debug(true).Driver(driver).InsertBatch(&batch) + count, err := aorm.Db(db).Debug(false).Driver(driver).InsertBatch(&batch) if err != nil { panic(driver + " testInsertBatch " + "found err:" + err.Error()) }