This commit is contained in:
tangpanqing
2023-01-10 11:59:13 +08:00
parent 10a7bc9230
commit 1d2d090efa
9 changed files with 91 additions and 81 deletions

View File

@@ -3,7 +3,6 @@ package aorm
import (
"database/sql" //只需导入你需要的驱动即可
"github.com/tangpanqing/aorm/builder"
"github.com/tangpanqing/aorm/cache"
"github.com/tangpanqing/aorm/migrator"
"github.com/tangpanqing/aorm/model"
)
@@ -15,7 +14,7 @@ type DbContent struct {
}
func Store(destList ...interface{}) {
cache.Store(destList...)
builder.Store(destList...)
}
//Open 开始一个数据库连接

View File

@@ -2,7 +2,6 @@ package builder
import (
"fmt"
"github.com/tangpanqing/aorm/cache"
"github.com/tangpanqing/aorm/helper"
"reflect"
"strings"
@@ -90,9 +89,9 @@ func getPrefixByField(field interface{}, prefix ...string) string {
valueOf := reflect.ValueOf(field)
if reflect.Ptr == valueOf.Kind() {
fieldPointer := valueOf.Pointer()
tablePointer := cache.GetFieldMap(fieldPointer).TablePointer
tablePointer := getFieldMap(fieldPointer).TablePointer
tableName := cache.GetTableMap(tablePointer)
tableName := getTableMap(tablePointer)
strArr := strings.Split(tableName, ".")
str = helper.UnderLine(strArr[len(strArr)-1])
} else {
@@ -111,7 +110,7 @@ func getTableNameByTable(table interface{}) string {
valueOf := reflect.ValueOf(table)
if reflect.Ptr == valueOf.Kind() {
tableName := cache.GetTableMap(valueOf.Pointer())
tableName := getTableMap(valueOf.Pointer())
strArr := strings.Split(tableName, ".")
return helper.UnderLine(strArr[len(strArr)-1])
} else {
@@ -123,7 +122,7 @@ func getTableNameByTable(table interface{}) string {
func getFieldName(field interface{}) string {
valueOf := reflect.ValueOf(field)
if reflect.Ptr == valueOf.Kind() {
return helper.UnderLine(cache.GetFieldMap(reflect.ValueOf(field).Pointer()).Name)
return helper.UnderLine(getFieldMap(reflect.ValueOf(field).Pointer()).Name)
} else {
return fmt.Sprintf("%v", field)
}

59
builder/cache.go Normal file
View File

@@ -0,0 +1,59 @@
package builder
import (
"github.com/tangpanqing/aorm/helper"
"github.com/tangpanqing/aorm/model"
"reflect"
)
var TableMap = make(map[uintptr]string)
var FieldMap = make(map[uintptr]model.FieldInfo)
//Store 保存到缓存
func Store(destList ...interface{}) {
for i := 0; i < len(destList); i++ {
dest := destList[i]
valueOf := reflect.ValueOf(dest)
typeof := reflect.TypeOf(dest)
tablePointer := valueOf.Pointer()
setTableMap(tablePointer, getTableNameByReflect(typeof, valueOf))
for j := 0; j < valueOf.Elem().NumField(); j++ {
addr := valueOf.Elem().Field(j).Addr().Pointer()
key, _ := getFieldNameByReflect(typeof.Elem().Field(j))
setFieldMap(addr, model.FieldInfo{
TablePointer: tablePointer,
Name: key,
})
}
}
}
func setTableMap(tablePointer uintptr, name string) {
TableMap[tablePointer] = name
}
func getTableMap(tablePointer uintptr) string {
return TableMap[tablePointer]
}
func setFieldMap(fieldPointer uintptr, fieldInfo model.FieldInfo) {
FieldMap[fieldPointer] = fieldInfo
}
func getFieldMap(fieldPointer uintptr) model.FieldInfo {
return FieldMap[fieldPointer]
}
func getFieldNameByReflect(field reflect.StructField) (string, map[string]string) {
key := helper.UnderLine(field.Name)
tag := field.Tag.Get("aorm")
tagMap := getTagMap(tag)
if column, ok := tagMap["column"]; ok {
key = column
}
return key, tagMap
}

View File

@@ -4,7 +4,6 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/tangpanqing/aorm/helper"
"github.com/tangpanqing/aorm/model"
"reflect"
"strconv"
@@ -81,6 +80,21 @@ func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value)
return getTableNameByReflect(typeOf, valueOf)
}
func getTagMap(fieldTag string) map[string]string {
var fieldMap = make(map[string]string)
if "" != fieldTag {
tagArr := strings.Split(fieldTag, ";")
for j := 0; j < len(tagArr); j++ {
tagArrArr := strings.Split(tagArr[j], ":")
fieldMap[tagArrArr[0]] = ""
if len(tagArrArr) > 1 {
fieldMap[tagArrArr[0]] = tagArrArr[1]
}
}
}
return fieldMap
}
// Insert 增加记录
func (b *Builder) Insert(dest interface{}) (int64, error) {
typeOf := reflect.TypeOf(dest)
@@ -93,12 +107,11 @@ func (b *Builder) Insert(dest interface{}) (int64, error) {
var paramList []any
var place []string
for i := 0; i < typeOf.Elem().NumField(); i++ {
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
key, tagMap := getFieldNameByReflect(typeOf.Elem().Field(i))
//如果是Postgres数据库寻找主键
if b.driverName == model.Postgres {
tag := typeOf.Elem().Field(i).Tag.Get("aorm")
if -1 != strings.Index(tag, "primary") {
if _, ok := tagMap["primary"]; ok {
primaryKey = key
}
}
@@ -193,7 +206,7 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) {
isNotNull := valueOf.Index(j).Elem().Field(i).Field(0).Field(1).Bool()
if isNotNull {
if j == 0 {
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
key, _ := getFieldNameByReflect(typeOf.Elem().Field(i))
keys = append(keys, key)
}

View File

@@ -1,7 +1,6 @@
package builder
import (
"github.com/tangpanqing/aorm/helper"
"github.com/tangpanqing/aorm/model"
"reflect"
"strings"
@@ -91,7 +90,8 @@ func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramLis
for i := 0; i < typeOf.Elem().NumField(); i++ {
isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool()
if isNotNull {
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
key, _ := getFieldNameByReflect(typeOf.Elem().Field(i))
val := valueOf.Elem().Field(i).Field(0).Field(0).Interface()
keys = append(keys, key+"=?")

63
cache/cache.go vendored
View File

@@ -1,63 +0,0 @@
package cache
import (
"github.com/tangpanqing/aorm/helper"
"github.com/tangpanqing/aorm/model"
"reflect"
"strings"
)
var TableMap = make(map[uintptr]string)
var FieldMap = make(map[uintptr]model.FieldInfo)
//Store 保存到缓存
func Store(destList ...interface{}) {
for i := 0; i < len(destList); i++ {
dest := destList[i]
valueOf := reflect.ValueOf(dest)
typeof := reflect.TypeOf(dest)
tablePointer := valueOf.Pointer()
SetTableMap(tablePointer, getTableNameByReflect(typeof, valueOf))
for j := 0; j < valueOf.Elem().NumField(); j++ {
addr := valueOf.Elem().Field(j).Addr().Pointer()
name := typeof.Elem().Field(j).Name
SetFieldMap(addr, model.FieldInfo{
TablePointer: tablePointer,
Name: name,
})
}
}
}
func SetTableMap(tablePointer uintptr, name string) {
TableMap[tablePointer] = name
}
func GetTableMap(tablePointer uintptr) string {
return TableMap[tablePointer]
}
func SetFieldMap(fieldPointer uintptr, fieldInfo model.FieldInfo) {
FieldMap[fieldPointer] = fieldInfo
}
func GetFieldMap(fieldPointer uintptr) model.FieldInfo {
return FieldMap[fieldPointer]
}
//反射表名,优先从方法获取,没有方法则从名字获取
func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
method, isSet := typeOf.MethodByName("TableName")
if isSet {
var paramList []reflect.Value
paramList = append(paramList, valueOf)
res := method.Func.Call(paramList)
return res[0].String()
} else {
arr := strings.Split(typeOf.String(), ".")
return helper.UnderLine(arr[len(arr)-1])
}
}

View File

@@ -106,6 +106,12 @@ func (mm *MigrateExecutor) getColumnsFromCode(typeOf reflect.Type) []Column {
fieldName := helper.UnderLine(typeOf.Elem().Field(i).Name)
fieldType := typeOf.Elem().Field(i).Type.Name()
fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm"))
//如果tag里重新设置了字段名
if column, ok := fieldMap["column"]; ok {
fieldName = column
}
columnsFromCode = append(columnsFromCode, Column{
ColumnName: null.StringFrom(fieldName),
DataType: null.StringFrom(getDataType(fieldType, fieldMap)),

View File

@@ -117,10 +117,6 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf
}
}
func (mi *Migrator) GetOpinionList() []model.OpinionItem {
return mi.opinionList
}
//反射表名,优先从方法获取,没有方法则从名字获取
func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
method, isSet := typeOf.MethodByName("TableName")

View File

@@ -107,6 +107,7 @@ func TestAll(t *testing.T) {
dbItem := dbList[i]
testMigrate(dbItem.DriverName, dbItem.DbLink)
testShowCreateTable(dbItem.DriverName, dbItem.DbLink)
id := testInsert(dbItem.DriverName, dbItem.DbLink)