This commit is contained in:
tangpanqing
2023-01-10 09:52:08 +08:00
parent bc43028389
commit 10a7bc9230
8 changed files with 45 additions and 60 deletions

View File

@@ -130,13 +130,12 @@ func getFieldName(field interface{}) string {
} }
//反射表名,优先从方法获取,没有方法则从名字获取 //反射表名,优先从方法获取,没有方法则从名字获取
func getTableNameByReflect(typeOf reflect.Type) string { func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
method, isSet := typeOf.MethodByName("TableName") method, isSet := typeOf.MethodByName("TableName")
fmt.Println("=isSet=")
fmt.Println(typeOf)
fmt.Println(isSet)
if isSet { if isSet {
res := method.Func.Call(nil) var paramList []reflect.Value
paramList = append(paramList, valueOf)
res := method.Func.Call(paramList)
return res[0].String() return res[0].String()
} else { } else {
arr := strings.Split(typeOf.String(), ".") arr := strings.Split(typeOf.String(), ".")

View File

@@ -78,7 +78,7 @@ func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value)
return getTableNameByTable(b.table) return getTableNameByTable(b.table)
} }
return getTableNameByReflect(typeOf) return getTableNameByReflect(typeOf, valueOf)
} }
// Insert 增加记录 // Insert 增加记录
@@ -175,24 +175,11 @@ func convertToPostgresSql(sqlStr string) string {
// InsertBatch 批量增加记录 // InsertBatch 批量增加记录
func (b *Builder) InsertBatch(values interface{}) (int64, error) { func (b *Builder) InsertBatch(values interface{}) (int64, error) {
TypeOf := reflect.TypeOf(values)
ValueOf := reflect.ValueOf(values)
fmt.Println(TypeOf)
fmt.Println(ValueOf)
fmt.Println(TypeOf.Elem())
fmt.Println(ValueOf.Elem())
fmt.Println(ValueOf.NumField())
fmt.Println(ValueOf.Elem().NumField())
return 0, nil
var keys []string var keys []string
var paramList []any var paramList []any
var place []string var place []string
valueOf := reflect.ValueOf(values).Elem() valueOf := reflect.ValueOf(values).Elem()
fmt.Println(valueOf.NumField())
if valueOf.Len() == 0 { if valueOf.Len() == 0 {
return 0, errors.New("the data list for insert batch not found") return 0, errors.New("the data list for insert batch not found")
@@ -202,15 +189,15 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) {
for j := 0; j < valueOf.Len(); j++ { for j := 0; j < valueOf.Len(); j++ {
var placeItem []string var placeItem []string
for i := 0; i < valueOf.Index(j).NumField(); i++ { for i := 0; i < valueOf.Index(j).Elem().NumField(); i++ {
isNotNull := valueOf.Index(j).Field(i).Field(0).Field(1).Bool() isNotNull := valueOf.Index(j).Elem().Field(i).Field(0).Field(1).Bool()
if isNotNull { if isNotNull {
if j == 0 { if j == 0 {
key := helper.UnderLine(typeOf.Field(i).Name) key := helper.UnderLine(typeOf.Elem().Field(i).Name)
keys = append(keys, key) keys = append(keys, key)
} }
val := valueOf.Index(j).Field(i).Field(0).Field(0).Interface() val := valueOf.Index(j).Elem().Field(i).Field(0).Field(0).Interface()
paramList = append(paramList, val) paramList = append(paramList, val)
placeItem = append(placeItem, "?") placeItem = append(placeItem, "?")
} }
@@ -219,7 +206,6 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) {
place = append(place, "("+strings.Join(placeItem, ",")+")") place = append(place, "("+strings.Join(placeItem, ",")+")")
} }
fmt.Println("--InsertBatch--")
sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf.Index(0)) + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf.Index(0)) + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",")
if b.driverName == model.Postgres { if b.driverName == model.Postgres {

View File

@@ -84,7 +84,7 @@ func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramLis
//如果没有设置表名 //如果没有设置表名
if b.tableName == "" { if b.tableName == "" {
b.tableName = getTableNameByReflect(typeOf) b.tableName = getTableNameByReflect(typeOf, valueOf)
} }
var keys []string var keys []string

View File

@@ -12,7 +12,7 @@ func (b *Builder) Having(dest interface{}) *Builder {
//如果没有设置表名 //如果没有设置表名
if b.tableName == "" { if b.tableName == "" {
b.tableName = getTableNameByReflect(typeOf) b.tableName = getTableNameByReflect(typeOf, valueOf)
} }
for i := 0; i < typeOf.Elem().NumField(); i++ { for i := 0; i < typeOf.Elem().NumField(); i++ {

View File

@@ -12,7 +12,7 @@ func (b *Builder) Where(dest interface{}) *Builder {
//如果没有设置表名 //如果没有设置表名
if b.tableName == "" { if b.tableName == "" {
b.tableName = getTableNameByReflect(typeOf) b.tableName = getTableNameByReflect(typeOf, valueOf)
} }
for i := 0; i < typeOf.Elem().NumField(); i++ { for i := 0; i < typeOf.Elem().NumField(); i++ {

View File

@@ -53,8 +53,8 @@ func (mm *MigrateExecutor) ShowCreateTable(tableName string) string {
} }
//MigrateCommon 迁移的主要过程 //MigrateCommon 迁移的主要过程
func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type) error { func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) error {
tableFromCode := mm.getTableFromCode(tableName, typeOf) tableFromCode := mm.getTableFromCode(tableName, typeOf, valueOf)
columnsFromCode := mm.getColumnsFromCode(typeOf) columnsFromCode := mm.getColumnsFromCode(typeOf)
indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode) indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode)
@@ -77,7 +77,7 @@ func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type)
return nil return nil
} }
func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type) Table { func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type, valueOf reflect.Value) Table {
table := Table{ table := Table{
TableName: null.StringFrom(tableName), TableName: null.StringFrom(tableName),
Engine: null.StringFrom("MyISAM"), Engine: null.StringFrom("MyISAM"),
@@ -86,7 +86,9 @@ func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Typ
method, isSet := typeOf.MethodByName("TableOpinion") method, isSet := typeOf.MethodByName("TableOpinion")
if isSet { if isSet {
valueList := method.Func.Call(nil) var paramList []reflect.Value
paramList = append(paramList, valueOf)
valueList := method.Func.Call(paramList)
i := valueList[0].Interface() i := valueList[0].Interface()
m := i.(map[string]string) m := i.(map[string]string)

View File

@@ -58,18 +58,20 @@ func (mi *Migrator) AutoMigrate(destList ...interface{}) {
for i := 0; i < len(destList); i++ { for i := 0; i < len(destList); i++ {
dest := destList[i] dest := destList[i]
typeOf := reflect.TypeOf(dest) typeOf := reflect.TypeOf(dest)
tableName := getTableNameByReflect(typeOf) valueOf := reflect.ValueOf(dest)
mi.migrateCommon(tableName, typeOf) tableName := getTableNameByReflect(typeOf, valueOf)
mi.migrateCommon(tableName, typeOf, valueOf)
} }
} }
// Migrate 自动迁移数据库结构,需要输入数据库名,表名 // Migrate 自动迁移数据库结构,需要输入数据库名,表名
func (mi *Migrator) Migrate(tableName string, dest interface{}) { func (mi *Migrator) Migrate(tableName string, dest interface{}) {
typeOf := reflect.TypeOf(dest) typeOf := reflect.TypeOf(dest)
mi.migrateCommon(tableName, typeOf) valueOf := reflect.ValueOf(dest)
mi.migrateCommon(tableName, typeOf, valueOf)
} }
func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) {
if mi.driverName == model.Mssql { if mi.driverName == model.Mssql {
me := migrate_mssql.MigrateExecutor{ me := migrate_mssql.MigrateExecutor{
DriverName: mi.driverName, DriverName: mi.driverName,
@@ -89,7 +91,7 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) {
LinkCommon: mi.LinkCommon, LinkCommon: mi.LinkCommon,
}, },
} }
me.MigrateCommon(tableName, typeOf) me.MigrateCommon(tableName, typeOf, valueOf)
} }
if mi.driverName == model.Sqlite3 { if mi.driverName == model.Sqlite3 {
@@ -120,10 +122,12 @@ func (mi *Migrator) GetOpinionList() []model.OpinionItem {
} }
//反射表名,优先从方法获取,没有方法则从名字获取 //反射表名,优先从方法获取,没有方法则从名字获取
func getTableNameByReflect(typeOf reflect.Type) string { func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
method, isSet := typeOf.MethodByName("TableName") method, isSet := typeOf.MethodByName("TableName")
if isSet { if isSet {
res := method.Func.Call(nil) var paramList []reflect.Value
paramList = append(paramList, valueOf)
res := method.Func.Call(paramList)
return res[0].String() return res[0].String()
} else { } else {
arr := strings.Split(typeOf.String(), ".") arr := strings.Split(typeOf.String(), ".")

View File

@@ -60,10 +60,6 @@ type Person struct {
Test null.Float `aorm:"type:double;comment:测试" json:"test"` Test null.Float `aorm:"type:double;comment:测试" json:"test"`
} }
func (p *Person) TableName() string {
return "erp_person"
}
func (p *Person) TableOpinion() map[string]string { func (p *Person) TableOpinion() map[string]string {
return map[string]string{ return map[string]string{
"ENGINE": "InnoDB", "ENGINE": "InnoDB",
@@ -100,11 +96,12 @@ func TestAll(t *testing.T) {
aorm.Store(&articleVO) aorm.Store(&articleVO)
aorm.Store(&personAge, &personWithArticleCount) aorm.Store(&personAge, &personWithArticleCount)
dbList := make([]aorm.DbContent, 0) var dbList = []aorm.DbContent{
dbList = append(dbList, testSqlite3Connect()) testMysqlConnect(),
dbList = append(dbList, testMysqlConnect()) testSqlite3Connect(),
dbList = append(dbList, testPostgresConnect()) testPostgresConnect(),
dbList = append(dbList, testMssqlConnect()) testMssqlConnect(),
}
for i := 0; i < len(dbList); i++ { for i := 0; i < len(dbList); i++ {
dbItem := dbList[i] dbItem := dbList[i]
@@ -115,7 +112,6 @@ func TestAll(t *testing.T) {
id := testInsert(dbItem.DriverName, dbItem.DbLink) id := testInsert(dbItem.DriverName, dbItem.DbLink)
testInsertBatch(dbItem.DriverName, dbItem.DbLink) testInsertBatch(dbItem.DriverName, dbItem.DbLink)
break
testGetOne(dbItem.DriverName, dbItem.DbLink, id) testGetOne(dbItem.DriverName, dbItem.DbLink, id)
testGetMany(dbItem.DriverName, dbItem.DbLink) testGetMany(dbItem.DriverName, dbItem.DbLink)
testUpdate(dbItem.DriverName, dbItem.DbLink, id) testUpdate(dbItem.DriverName, dbItem.DbLink, id)
@@ -227,10 +223,8 @@ func testMssqlConnect() aorm.DbContent {
} }
func testMigrate(driver string, db *sql.DB) { func testMigrate(driver string, db *sql.DB) {
//AutoMigrate
aorm.Migrator(db).Driver(driver).AutoMigrate(&person, &article, &student) aorm.Migrator(db).Driver(driver).AutoMigrate(&person, &article, &student)
//Migrate
aorm.Migrator(db).Driver(driver).Migrate("person_1", &person) aorm.Migrator(db).Driver(driver).Migrate("person_1", &person)
} }
@@ -249,15 +243,15 @@ func testInsert(driver string, db *sql.DB) int64 {
Test: null.FloatFrom(2), Test: null.FloatFrom(2),
} }
id, errInsert := aorm.Db(db).Debug(true).Driver(driver).Insert(&obj) id, errInsert := aorm.Db(db).Debug(false).Driver(driver).Insert(&obj)
if errInsert != nil { if errInsert != nil {
panic(driver + " testInsert " + "found err: " + errInsert.Error()) panic(driver + " testInsert " + "found err: " + errInsert.Error())
} }
//aorm.Db(db).Debug(false).Driver(driver).Insert(&Article{ aorm.Db(db).Debug(false).Driver(driver).Insert(&Article{
// Type: null.IntFrom(0), Type: null.IntFrom(0),
// PersonId: null.IntFrom(id), PersonId: null.IntFrom(id),
// ArticleBody: null.StringFrom("文章内容"), ArticleBody: null.StringFrom("文章内容"),
//}) })
var personItem Person var personItem Person
err := aorm.Db(db).Table(&person).Debug(false).Driver(driver).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) err := aorm.Db(db).Table(&person).Debug(false).Driver(driver).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem)
@@ -290,9 +284,9 @@ func testInsert(driver string, db *sql.DB) int64 {
} }
//测试非id主键 //测试非id主键
//aorm.Db(db).Debug(false).Driver(driver).Insert(&Student{ aorm.Db(db).Debug(false).Driver(driver).Insert(&Student{
// Name: null.StringFrom("new student"), Name: null.StringFrom("new student"),
//}) })
return id return id
} }
@@ -319,7 +313,7 @@ func testInsertBatch(driver string, db *sql.DB) int64 {
Test: null.FloatFrom(200.15987654321987654321), Test: null.FloatFrom(200.15987654321987654321),
}) })
count, err := aorm.Db(db).Debug(true).Driver(driver).InsertBatch(&batch) count, err := aorm.Db(db).Debug(false).Driver(driver).InsertBatch(&batch)
if err != nil { if err != nil {
panic(driver + " testInsertBatch " + "found err:" + err.Error()) panic(driver + " testInsertBatch " + "found err:" + err.Error())
} }