mirror of
https://github.com/tangpanqing/aorm.git
synced 2025-10-16 12:51:23 +08:00
new test
This commit is contained in:
3
aorm.go
3
aorm.go
@@ -3,7 +3,6 @@ package aorm
|
||||
import (
|
||||
"database/sql" //只需导入你需要的驱动即可
|
||||
"github.com/tangpanqing/aorm/builder"
|
||||
"github.com/tangpanqing/aorm/cache"
|
||||
"github.com/tangpanqing/aorm/migrator"
|
||||
"github.com/tangpanqing/aorm/model"
|
||||
)
|
||||
@@ -15,7 +14,7 @@ type DbContent struct {
|
||||
}
|
||||
|
||||
func Store(destList ...interface{}) {
|
||||
cache.Store(destList...)
|
||||
builder.Store(destList...)
|
||||
}
|
||||
|
||||
//Open 开始一个数据库连接
|
||||
|
@@ -2,7 +2,6 @@ package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/tangpanqing/aorm/cache"
|
||||
"github.com/tangpanqing/aorm/helper"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -90,9 +89,9 @@ func getPrefixByField(field interface{}, prefix ...string) string {
|
||||
valueOf := reflect.ValueOf(field)
|
||||
if reflect.Ptr == valueOf.Kind() {
|
||||
fieldPointer := valueOf.Pointer()
|
||||
tablePointer := cache.GetFieldMap(fieldPointer).TablePointer
|
||||
tablePointer := getFieldMap(fieldPointer).TablePointer
|
||||
|
||||
tableName := cache.GetTableMap(tablePointer)
|
||||
tableName := getTableMap(tablePointer)
|
||||
strArr := strings.Split(tableName, ".")
|
||||
str = helper.UnderLine(strArr[len(strArr)-1])
|
||||
} else {
|
||||
@@ -111,7 +110,7 @@ func getTableNameByTable(table interface{}) string {
|
||||
|
||||
valueOf := reflect.ValueOf(table)
|
||||
if reflect.Ptr == valueOf.Kind() {
|
||||
tableName := cache.GetTableMap(valueOf.Pointer())
|
||||
tableName := getTableMap(valueOf.Pointer())
|
||||
strArr := strings.Split(tableName, ".")
|
||||
return helper.UnderLine(strArr[len(strArr)-1])
|
||||
} else {
|
||||
@@ -123,7 +122,7 @@ func getTableNameByTable(table interface{}) string {
|
||||
func getFieldName(field interface{}) string {
|
||||
valueOf := reflect.ValueOf(field)
|
||||
if reflect.Ptr == valueOf.Kind() {
|
||||
return helper.UnderLine(cache.GetFieldMap(reflect.ValueOf(field).Pointer()).Name)
|
||||
return helper.UnderLine(getFieldMap(reflect.ValueOf(field).Pointer()).Name)
|
||||
} else {
|
||||
return fmt.Sprintf("%v", field)
|
||||
}
|
||||
|
59
builder/cache.go
Normal file
59
builder/cache.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"github.com/tangpanqing/aorm/helper"
|
||||
"github.com/tangpanqing/aorm/model"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var TableMap = make(map[uintptr]string)
|
||||
var FieldMap = make(map[uintptr]model.FieldInfo)
|
||||
|
||||
//Store 保存到缓存
|
||||
func Store(destList ...interface{}) {
|
||||
for i := 0; i < len(destList); i++ {
|
||||
dest := destList[i]
|
||||
valueOf := reflect.ValueOf(dest)
|
||||
typeof := reflect.TypeOf(dest)
|
||||
|
||||
tablePointer := valueOf.Pointer()
|
||||
setTableMap(tablePointer, getTableNameByReflect(typeof, valueOf))
|
||||
|
||||
for j := 0; j < valueOf.Elem().NumField(); j++ {
|
||||
addr := valueOf.Elem().Field(j).Addr().Pointer()
|
||||
key, _ := getFieldNameByReflect(typeof.Elem().Field(j))
|
||||
|
||||
setFieldMap(addr, model.FieldInfo{
|
||||
TablePointer: tablePointer,
|
||||
Name: key,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setTableMap(tablePointer uintptr, name string) {
|
||||
TableMap[tablePointer] = name
|
||||
}
|
||||
|
||||
func getTableMap(tablePointer uintptr) string {
|
||||
return TableMap[tablePointer]
|
||||
}
|
||||
|
||||
func setFieldMap(fieldPointer uintptr, fieldInfo model.FieldInfo) {
|
||||
FieldMap[fieldPointer] = fieldInfo
|
||||
}
|
||||
|
||||
func getFieldMap(fieldPointer uintptr) model.FieldInfo {
|
||||
return FieldMap[fieldPointer]
|
||||
}
|
||||
|
||||
func getFieldNameByReflect(field reflect.StructField) (string, map[string]string) {
|
||||
key := helper.UnderLine(field.Name)
|
||||
tag := field.Tag.Get("aorm")
|
||||
tagMap := getTagMap(tag)
|
||||
if column, ok := tagMap["column"]; ok {
|
||||
key = column
|
||||
}
|
||||
|
||||
return key, tagMap
|
||||
}
|
@@ -4,7 +4,6 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/tangpanqing/aorm/helper"
|
||||
"github.com/tangpanqing/aorm/model"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -81,6 +80,21 @@ func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value)
|
||||
return getTableNameByReflect(typeOf, valueOf)
|
||||
}
|
||||
|
||||
func getTagMap(fieldTag string) map[string]string {
|
||||
var fieldMap = make(map[string]string)
|
||||
if "" != fieldTag {
|
||||
tagArr := strings.Split(fieldTag, ";")
|
||||
for j := 0; j < len(tagArr); j++ {
|
||||
tagArrArr := strings.Split(tagArr[j], ":")
|
||||
fieldMap[tagArrArr[0]] = ""
|
||||
if len(tagArrArr) > 1 {
|
||||
fieldMap[tagArrArr[0]] = tagArrArr[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return fieldMap
|
||||
}
|
||||
|
||||
// Insert 增加记录
|
||||
func (b *Builder) Insert(dest interface{}) (int64, error) {
|
||||
typeOf := reflect.TypeOf(dest)
|
||||
@@ -93,12 +107,11 @@ func (b *Builder) Insert(dest interface{}) (int64, error) {
|
||||
var paramList []any
|
||||
var place []string
|
||||
for i := 0; i < typeOf.Elem().NumField(); i++ {
|
||||
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
|
||||
key, tagMap := getFieldNameByReflect(typeOf.Elem().Field(i))
|
||||
|
||||
//如果是Postgres数据库,寻找主键
|
||||
if b.driverName == model.Postgres {
|
||||
tag := typeOf.Elem().Field(i).Tag.Get("aorm")
|
||||
if -1 != strings.Index(tag, "primary") {
|
||||
if _, ok := tagMap["primary"]; ok {
|
||||
primaryKey = key
|
||||
}
|
||||
}
|
||||
@@ -193,7 +206,7 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) {
|
||||
isNotNull := valueOf.Index(j).Elem().Field(i).Field(0).Field(1).Bool()
|
||||
if isNotNull {
|
||||
if j == 0 {
|
||||
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
|
||||
key, _ := getFieldNameByReflect(typeOf.Elem().Field(i))
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"github.com/tangpanqing/aorm/helper"
|
||||
"github.com/tangpanqing/aorm/model"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -91,7 +90,8 @@ func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramLis
|
||||
for i := 0; i < typeOf.Elem().NumField(); i++ {
|
||||
isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool()
|
||||
if isNotNull {
|
||||
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
|
||||
key, _ := getFieldNameByReflect(typeOf.Elem().Field(i))
|
||||
|
||||
val := valueOf.Elem().Field(i).Field(0).Field(0).Interface()
|
||||
|
||||
keys = append(keys, key+"=?")
|
||||
|
63
cache/cache.go
vendored
63
cache/cache.go
vendored
@@ -1,63 +0,0 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/tangpanqing/aorm/helper"
|
||||
"github.com/tangpanqing/aorm/model"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var TableMap = make(map[uintptr]string)
|
||||
var FieldMap = make(map[uintptr]model.FieldInfo)
|
||||
|
||||
//Store 保存到缓存
|
||||
func Store(destList ...interface{}) {
|
||||
for i := 0; i < len(destList); i++ {
|
||||
dest := destList[i]
|
||||
valueOf := reflect.ValueOf(dest)
|
||||
typeof := reflect.TypeOf(dest)
|
||||
|
||||
tablePointer := valueOf.Pointer()
|
||||
SetTableMap(tablePointer, getTableNameByReflect(typeof, valueOf))
|
||||
|
||||
for j := 0; j < valueOf.Elem().NumField(); j++ {
|
||||
addr := valueOf.Elem().Field(j).Addr().Pointer()
|
||||
name := typeof.Elem().Field(j).Name
|
||||
|
||||
SetFieldMap(addr, model.FieldInfo{
|
||||
TablePointer: tablePointer,
|
||||
Name: name,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SetTableMap(tablePointer uintptr, name string) {
|
||||
TableMap[tablePointer] = name
|
||||
}
|
||||
|
||||
func GetTableMap(tablePointer uintptr) string {
|
||||
return TableMap[tablePointer]
|
||||
}
|
||||
|
||||
func SetFieldMap(fieldPointer uintptr, fieldInfo model.FieldInfo) {
|
||||
FieldMap[fieldPointer] = fieldInfo
|
||||
}
|
||||
|
||||
func GetFieldMap(fieldPointer uintptr) model.FieldInfo {
|
||||
return FieldMap[fieldPointer]
|
||||
}
|
||||
|
||||
//反射表名,优先从方法获取,没有方法则从名字获取
|
||||
func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
|
||||
method, isSet := typeOf.MethodByName("TableName")
|
||||
if isSet {
|
||||
var paramList []reflect.Value
|
||||
paramList = append(paramList, valueOf)
|
||||
res := method.Func.Call(paramList)
|
||||
return res[0].String()
|
||||
} else {
|
||||
arr := strings.Split(typeOf.String(), ".")
|
||||
return helper.UnderLine(arr[len(arr)-1])
|
||||
}
|
||||
}
|
@@ -106,6 +106,12 @@ func (mm *MigrateExecutor) getColumnsFromCode(typeOf reflect.Type) []Column {
|
||||
fieldName := helper.UnderLine(typeOf.Elem().Field(i).Name)
|
||||
fieldType := typeOf.Elem().Field(i).Type.Name()
|
||||
fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm"))
|
||||
|
||||
//如果tag里重新设置了字段名
|
||||
if column, ok := fieldMap["column"]; ok {
|
||||
fieldName = column
|
||||
}
|
||||
|
||||
columnsFromCode = append(columnsFromCode, Column{
|
||||
ColumnName: null.StringFrom(fieldName),
|
||||
DataType: null.StringFrom(getDataType(fieldType, fieldMap)),
|
||||
|
@@ -117,10 +117,6 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf
|
||||
}
|
||||
}
|
||||
|
||||
func (mi *Migrator) GetOpinionList() []model.OpinionItem {
|
||||
return mi.opinionList
|
||||
}
|
||||
|
||||
//反射表名,优先从方法获取,没有方法则从名字获取
|
||||
func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
|
||||
method, isSet := typeOf.MethodByName("TableName")
|
||||
|
@@ -107,6 +107,7 @@ func TestAll(t *testing.T) {
|
||||
dbItem := dbList[i]
|
||||
|
||||
testMigrate(dbItem.DriverName, dbItem.DbLink)
|
||||
|
||||
testShowCreateTable(dbItem.DriverName, dbItem.DbLink)
|
||||
|
||||
id := testInsert(dbItem.DriverName, dbItem.DbLink)
|
||||
|
Reference in New Issue
Block a user