From 1d2d090efa7060ee16672aab6f6012c5fffaccde Mon Sep 17 00:00:00 2001 From: tangpanqing Date: Tue, 10 Jan 2023 11:59:13 +0800 Subject: [PATCH] new test --- aorm.go | 3 +- builder/builder.go | 9 +++--- builder/cache.go | 59 +++++++++++++++++++++++++++++++++++++ builder/crud.go | 23 +++++++++++---- builder/handle.go | 4 +-- cache/cache.go | 63 ---------------------------------------- migrate_mysql/migrate.go | 6 ++++ migrator/migrator.go | 4 --- test/aorm_test.go | 1 + 9 files changed, 91 insertions(+), 81 deletions(-) create mode 100644 builder/cache.go delete mode 100644 cache/cache.go diff --git a/aorm.go b/aorm.go index 66c1ac6..9d5ac83 100644 --- a/aorm.go +++ b/aorm.go @@ -3,7 +3,6 @@ package aorm import ( "database/sql" //只需导入你需要的驱动即可 "github.com/tangpanqing/aorm/builder" - "github.com/tangpanqing/aorm/cache" "github.com/tangpanqing/aorm/migrator" "github.com/tangpanqing/aorm/model" ) @@ -15,7 +14,7 @@ type DbContent struct { } func Store(destList ...interface{}) { - cache.Store(destList...) + builder.Store(destList...) } //Open 开始一个数据库连接 diff --git a/builder/builder.go b/builder/builder.go index 42c0ac7..f8b0eac 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -2,7 +2,6 @@ package builder import ( "fmt" - "github.com/tangpanqing/aorm/cache" "github.com/tangpanqing/aorm/helper" "reflect" "strings" @@ -90,9 +89,9 @@ func getPrefixByField(field interface{}, prefix ...string) string { valueOf := reflect.ValueOf(field) if reflect.Ptr == valueOf.Kind() { fieldPointer := valueOf.Pointer() - tablePointer := cache.GetFieldMap(fieldPointer).TablePointer + tablePointer := getFieldMap(fieldPointer).TablePointer - tableName := cache.GetTableMap(tablePointer) + tableName := getTableMap(tablePointer) strArr := strings.Split(tableName, ".") str = helper.UnderLine(strArr[len(strArr)-1]) } else { @@ -111,7 +110,7 @@ func getTableNameByTable(table interface{}) string { valueOf := reflect.ValueOf(table) if reflect.Ptr == valueOf.Kind() { - tableName := cache.GetTableMap(valueOf.Pointer()) + tableName := getTableMap(valueOf.Pointer()) strArr := strings.Split(tableName, ".") return helper.UnderLine(strArr[len(strArr)-1]) } else { @@ -123,7 +122,7 @@ func getTableNameByTable(table interface{}) string { func getFieldName(field interface{}) string { valueOf := reflect.ValueOf(field) if reflect.Ptr == valueOf.Kind() { - return helper.UnderLine(cache.GetFieldMap(reflect.ValueOf(field).Pointer()).Name) + return helper.UnderLine(getFieldMap(reflect.ValueOf(field).Pointer()).Name) } else { return fmt.Sprintf("%v", field) } diff --git a/builder/cache.go b/builder/cache.go new file mode 100644 index 0000000..98ca57c --- /dev/null +++ b/builder/cache.go @@ -0,0 +1,59 @@ +package builder + +import ( + "github.com/tangpanqing/aorm/helper" + "github.com/tangpanqing/aorm/model" + "reflect" +) + +var TableMap = make(map[uintptr]string) +var FieldMap = make(map[uintptr]model.FieldInfo) + +//Store 保存到缓存 +func Store(destList ...interface{}) { + for i := 0; i < len(destList); i++ { + dest := destList[i] + valueOf := reflect.ValueOf(dest) + typeof := reflect.TypeOf(dest) + + tablePointer := valueOf.Pointer() + setTableMap(tablePointer, getTableNameByReflect(typeof, valueOf)) + + for j := 0; j < valueOf.Elem().NumField(); j++ { + addr := valueOf.Elem().Field(j).Addr().Pointer() + key, _ := getFieldNameByReflect(typeof.Elem().Field(j)) + + setFieldMap(addr, model.FieldInfo{ + TablePointer: tablePointer, + Name: key, + }) + } + } +} + +func setTableMap(tablePointer uintptr, name string) { + TableMap[tablePointer] = name +} + +func getTableMap(tablePointer uintptr) string { + return TableMap[tablePointer] +} + +func setFieldMap(fieldPointer uintptr, fieldInfo model.FieldInfo) { + FieldMap[fieldPointer] = fieldInfo +} + +func getFieldMap(fieldPointer uintptr) model.FieldInfo { + return FieldMap[fieldPointer] +} + +func getFieldNameByReflect(field reflect.StructField) (string, map[string]string) { + key := helper.UnderLine(field.Name) + tag := field.Tag.Get("aorm") + tagMap := getTagMap(tag) + if column, ok := tagMap["column"]; ok { + key = column + } + + return key, tagMap +} diff --git a/builder/crud.go b/builder/crud.go index 4e11a35..d354151 100644 --- a/builder/crud.go +++ b/builder/crud.go @@ -4,7 +4,6 @@ import ( "database/sql" "errors" "fmt" - "github.com/tangpanqing/aorm/helper" "github.com/tangpanqing/aorm/model" "reflect" "strconv" @@ -81,6 +80,21 @@ func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value) return getTableNameByReflect(typeOf, valueOf) } +func getTagMap(fieldTag string) map[string]string { + var fieldMap = make(map[string]string) + if "" != fieldTag { + tagArr := strings.Split(fieldTag, ";") + for j := 0; j < len(tagArr); j++ { + tagArrArr := strings.Split(tagArr[j], ":") + fieldMap[tagArrArr[0]] = "" + if len(tagArrArr) > 1 { + fieldMap[tagArrArr[0]] = tagArrArr[1] + } + } + } + return fieldMap +} + // Insert 增加记录 func (b *Builder) Insert(dest interface{}) (int64, error) { typeOf := reflect.TypeOf(dest) @@ -93,12 +107,11 @@ func (b *Builder) Insert(dest interface{}) (int64, error) { var paramList []any var place []string for i := 0; i < typeOf.Elem().NumField(); i++ { - key := helper.UnderLine(typeOf.Elem().Field(i).Name) + key, tagMap := getFieldNameByReflect(typeOf.Elem().Field(i)) //如果是Postgres数据库,寻找主键 if b.driverName == model.Postgres { - tag := typeOf.Elem().Field(i).Tag.Get("aorm") - if -1 != strings.Index(tag, "primary") { + if _, ok := tagMap["primary"]; ok { primaryKey = key } } @@ -193,7 +206,7 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) { isNotNull := valueOf.Index(j).Elem().Field(i).Field(0).Field(1).Bool() if isNotNull { if j == 0 { - key := helper.UnderLine(typeOf.Elem().Field(i).Name) + key, _ := getFieldNameByReflect(typeOf.Elem().Field(i)) keys = append(keys, key) } diff --git a/builder/handle.go b/builder/handle.go index 74278f7..9a5316d 100644 --- a/builder/handle.go +++ b/builder/handle.go @@ -1,7 +1,6 @@ package builder import ( - "github.com/tangpanqing/aorm/helper" "github.com/tangpanqing/aorm/model" "reflect" "strings" @@ -91,7 +90,8 @@ func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramLis for i := 0; i < typeOf.Elem().NumField(); i++ { isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() if isNotNull { - key := helper.UnderLine(typeOf.Elem().Field(i).Name) + key, _ := getFieldNameByReflect(typeOf.Elem().Field(i)) + val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() keys = append(keys, key+"=?") diff --git a/cache/cache.go b/cache/cache.go deleted file mode 100644 index a26aedd..0000000 --- a/cache/cache.go +++ /dev/null @@ -1,63 +0,0 @@ -package cache - -import ( - "github.com/tangpanqing/aorm/helper" - "github.com/tangpanqing/aorm/model" - "reflect" - "strings" -) - -var TableMap = make(map[uintptr]string) -var FieldMap = make(map[uintptr]model.FieldInfo) - -//Store 保存到缓存 -func Store(destList ...interface{}) { - for i := 0; i < len(destList); i++ { - dest := destList[i] - valueOf := reflect.ValueOf(dest) - typeof := reflect.TypeOf(dest) - - tablePointer := valueOf.Pointer() - SetTableMap(tablePointer, getTableNameByReflect(typeof, valueOf)) - - for j := 0; j < valueOf.Elem().NumField(); j++ { - addr := valueOf.Elem().Field(j).Addr().Pointer() - name := typeof.Elem().Field(j).Name - - SetFieldMap(addr, model.FieldInfo{ - TablePointer: tablePointer, - Name: name, - }) - } - } -} - -func SetTableMap(tablePointer uintptr, name string) { - TableMap[tablePointer] = name -} - -func GetTableMap(tablePointer uintptr) string { - return TableMap[tablePointer] -} - -func SetFieldMap(fieldPointer uintptr, fieldInfo model.FieldInfo) { - FieldMap[fieldPointer] = fieldInfo -} - -func GetFieldMap(fieldPointer uintptr) model.FieldInfo { - return FieldMap[fieldPointer] -} - -//反射表名,优先从方法获取,没有方法则从名字获取 -func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string { - method, isSet := typeOf.MethodByName("TableName") - if isSet { - var paramList []reflect.Value - paramList = append(paramList, valueOf) - res := method.Func.Call(paramList) - return res[0].String() - } else { - arr := strings.Split(typeOf.String(), ".") - return helper.UnderLine(arr[len(arr)-1]) - } -} diff --git a/migrate_mysql/migrate.go b/migrate_mysql/migrate.go index 1d42e23..7e794b7 100644 --- a/migrate_mysql/migrate.go +++ b/migrate_mysql/migrate.go @@ -106,6 +106,12 @@ func (mm *MigrateExecutor) getColumnsFromCode(typeOf reflect.Type) []Column { fieldName := helper.UnderLine(typeOf.Elem().Field(i).Name) fieldType := typeOf.Elem().Field(i).Type.Name() fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) + + //如果tag里重新设置了字段名 + if column, ok := fieldMap["column"]; ok { + fieldName = column + } + columnsFromCode = append(columnsFromCode, Column{ ColumnName: null.StringFrom(fieldName), DataType: null.StringFrom(getDataType(fieldType, fieldMap)), diff --git a/migrator/migrator.go b/migrator/migrator.go index e2791a4..c0cd28f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -117,10 +117,6 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf } } -func (mi *Migrator) GetOpinionList() []model.OpinionItem { - return mi.opinionList -} - //反射表名,优先从方法获取,没有方法则从名字获取 func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string { method, isSet := typeOf.MethodByName("TableName") diff --git a/test/aorm_test.go b/test/aorm_test.go index 4b521e8..756297b 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -107,6 +107,7 @@ func TestAll(t *testing.T) { dbItem := dbList[i] testMigrate(dbItem.DriverName, dbItem.DbLink) + testShowCreateTable(dbItem.DriverName, dbItem.DbLink) id := testInsert(dbItem.DriverName, dbItem.DbLink)