mirror of
https://github.com/tangpanqing/aorm.git
synced 2025-10-08 17:30:55 +08:00
528 lines
13 KiB
Go
528 lines
13 KiB
Go
package aorm
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// Table holds the table-level attributes that are compared between the
// definition derived from code and the one read from the database.
type Table struct {
	// TableName is the name of the table.
	TableName string
	// Engine is the storage engine name, e.g. MyISAM.
	Engine string
	// Comment is the table comment as it is spliced into SQL; both Opinion and
	// getTableFromDb store it wrapped in single quotes.
	Comment string
}
|
|
|
|
// Column describes a single table column, either derived from a struct field
// or read from INFORMATION_SCHEMA.COLUMNS.
type Column struct {
	// ColumnName is the column name.
	ColumnName string
	// ColumnDefault is not populated by the code in this file.
	ColumnDefault string
	// IsNullable is "YES" or "NO".
	IsNullable string
	// DataType is the bare data type, e.g. varchar, bigint, int.
	DataType string
	// MaxLength is the maximum data length, e.g. 20.
	MaxLength int
	// ColumnType is the full column type, e.g. varchar(20).
	ColumnType string
	// ColumnComment is the column comment.
	ColumnComment string
	// Extra carries extra information, e.g. auto_increment.
	Extra string
	// DefaultVal is the default value.
	DefaultVal string
}
|
|
|
|
// Index describes one single-column index entry of a table.
type Index struct {
	// NonUnique is 0 for a unique (or primary) index and 1 for a plain index.
	NonUnique int
	// ColumnName is the indexed column.
	ColumnName string
	// KeyName is the index name; the primary key uses "PRIMARY".
	KeyName string
}
|
|
|
|
// OpinionItem is one key/value table option recorded through Opinion, such as
// ENGINE or COMMENT.
type OpinionItem struct {
	// Key is the option name.
	Key string
	// Val is the option value as it will be spliced into SQL.
	Val string
}
|
|
|
|
func (db *Executor) Opinion(key string, val string) *Executor {
|
|
if key == "COMMENT" {
|
|
val = "'" + val + "'"
|
|
}
|
|
|
|
db.OpinionList = append(db.OpinionList, OpinionItem{Key: key, Val: val})
|
|
|
|
return db
|
|
}
|
|
|
|
func (db *Executor) ShowCreateTable(tableName string) string {
|
|
list, _ := db.Query("show create table " + tableName)
|
|
return list[0]["Create Table"].(string)
|
|
}
|
|
|
|
// Migrate 迁移数据库结构,需要输入数据库名,表名自动获取
|
|
func (db *Executor) AutoMigrate(dest interface{}) {
|
|
typeOf := reflect.TypeOf(dest)
|
|
arr := strings.Split(typeOf.String(), ".")
|
|
tableName := UnderLine(arr[len(arr)-1])
|
|
|
|
db.migrateCommon(tableName, typeOf)
|
|
}
|
|
|
|
// AutoMigrate 自动迁移数据库结构,需要输入数据库名,表名
|
|
func (db *Executor) Migrate(tableName string, dest interface{}) {
|
|
typeOf := reflect.TypeOf(dest)
|
|
db.migrateCommon(tableName, typeOf)
|
|
}
|
|
|
|
func (db *Executor) migrateCommon(tableName string, typeOf reflect.Type) {
|
|
tableFromCode := db.getTableFromCode(tableName)
|
|
columnsFromCode := db.getColumnsFromCode(typeOf)
|
|
indexsFromCode := db.getIndexsFromCode(typeOf, tableFromCode)
|
|
|
|
//获取数据库名称
|
|
dbNameRows, _ := db.Query("SELECT DATABASE()")
|
|
dbName := dbNameRows[0]["DATABASE()"].(string)
|
|
|
|
//查询表信息,如果找不到就新建
|
|
sql := "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
|
|
dataList, _ := db.Query(sql)
|
|
if len(dataList) != 0 {
|
|
tableFromDb := getTableFromDb(dataList)
|
|
columnsFromDb := db.getColumnsFromDb(dbName, tableName)
|
|
indexsFromDb := db.getIndexsFromDb(tableName)
|
|
|
|
db.modifyTable(tableFromCode, columnsFromCode, indexsFromCode, tableFromDb, columnsFromDb, indexsFromDb)
|
|
} else {
|
|
db.createTable(tableFromCode, columnsFromCode, indexsFromCode)
|
|
}
|
|
}
|
|
|
|
func (db *Executor) getTableFromCode(tableName string) Table {
|
|
var tableFromCode Table
|
|
tableFromCode.TableName = tableName
|
|
tableFromCode.Engine = db.getValFromOpinion("ENGINE", "MyISAM")
|
|
tableFromCode.Comment = db.getValFromOpinion("COMMENT", "")
|
|
|
|
return tableFromCode
|
|
}
|
|
|
|
func (db *Executor) getColumnsFromCode(typeOf reflect.Type) []Column {
|
|
var columnsFromCode []Column
|
|
for i := 0; i < typeOf.Elem().NumField(); i++ {
|
|
fieldName := UnderLine(typeOf.Elem().Field(i).Name)
|
|
fieldType := typeOf.Elem().Field(i).Type.Name()
|
|
fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm"))
|
|
columnsFromCode = append(columnsFromCode, getColumnFromCode(fieldName, fieldType, fieldMap))
|
|
}
|
|
|
|
return columnsFromCode
|
|
}
|
|
|
|
func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table) []Index {
|
|
var indexsFromCode []Index
|
|
for i := 0; i < typeOf.Elem().NumField(); i++ {
|
|
fieldName := UnderLine(typeOf.Elem().Field(i).Name)
|
|
fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm"))
|
|
|
|
_, primaryIs := fieldMap["primary"]
|
|
if primaryIs {
|
|
indexsFromCode = append(indexsFromCode, Index{
|
|
NonUnique: 0,
|
|
ColumnName: fieldName,
|
|
KeyName: "PRIMARY",
|
|
})
|
|
}
|
|
|
|
_, uniqueIndexIs := fieldMap["unique"]
|
|
if uniqueIndexIs {
|
|
indexsFromCode = append(indexsFromCode, Index{
|
|
NonUnique: 0,
|
|
ColumnName: fieldName,
|
|
KeyName: "idx_" + tableFromCode.TableName + "_" + fieldName,
|
|
})
|
|
}
|
|
|
|
_, indexIs := fieldMap["index"]
|
|
if indexIs {
|
|
indexsFromCode = append(indexsFromCode, Index{
|
|
NonUnique: 1,
|
|
ColumnName: fieldName,
|
|
KeyName: "idx_" + tableFromCode.TableName + "_" + fieldName,
|
|
})
|
|
}
|
|
}
|
|
|
|
return indexsFromCode
|
|
}
|
|
|
|
func getTableFromDb(dataList []map[string]interface{}) Table {
|
|
var tableFromDb Table
|
|
tableFromDb.TableName = fmt.Sprintf("%v", dataList[0]["TABLE_NAME"])
|
|
tableFromDb.Engine = fmt.Sprintf("%v", dataList[0]["ENGINE"])
|
|
tableFromDb.Comment = "'" + fmt.Sprintf("%v", dataList[0]["TABLE_COMMENT"]) + "'"
|
|
|
|
return tableFromDb
|
|
}
|
|
|
|
func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column {
|
|
var columnsFromDb []Column
|
|
|
|
sqlColumn := "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
|
|
dataColumn, _ := db.Query(sqlColumn)
|
|
|
|
for j := 0; j < len(dataColumn); j++ {
|
|
maxLength, _ := strconv.Atoi(fmt.Sprintf("%v", dataColumn[j]["CHARACTER_MAXIMUM_LENGTH"]))
|
|
defaultVal := ""
|
|
if dataColumn[j]["COLUMN_DEFAULT"] != nil {
|
|
defaultVal = dataColumn[j]["COLUMN_DEFAULT"].(string)
|
|
}
|
|
|
|
columnsFromDb = append(columnsFromDb, Column{
|
|
ColumnName: dataColumn[j]["COLUMN_NAME"].(string),
|
|
DataType: dataColumn[j]["DATA_TYPE"].(string),
|
|
IsNullable: dataColumn[j]["IS_NULLABLE"].(string),
|
|
MaxLength: maxLength,
|
|
ColumnType: dataColumn[j]["COLUMN_TYPE"].(string),
|
|
ColumnComment: dataColumn[j]["COLUMN_COMMENT"].(string),
|
|
Extra: dataColumn[j]["EXTRA"].(string),
|
|
DefaultVal: defaultVal,
|
|
})
|
|
}
|
|
|
|
return columnsFromDb
|
|
}
|
|
|
|
func (db *Executor) getIndexsFromDb(tableName string) []Index {
|
|
sqlIndex := "SHOW INDEXES FROM " + tableName
|
|
dataIndex, _ := db.Query(sqlIndex)
|
|
|
|
var indexsFromDb []Index
|
|
for j := 0; j < len(dataIndex); j++ {
|
|
nonUnique, _ := strconv.Atoi(fmt.Sprintf("%v", dataIndex[j]["Non_unique"]))
|
|
indexsFromDb = append(indexsFromDb, Index{
|
|
ColumnName: fmt.Sprintf("%v", dataIndex[j]["Column_name"]),
|
|
KeyName: fmt.Sprintf("%v", dataIndex[j]["Key_name"]),
|
|
NonUnique: nonUnique,
|
|
})
|
|
}
|
|
|
|
return indexsFromDb
|
|
}
|
|
|
|
// 修改表
|
|
func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexsFromDb []Index) {
|
|
//fmt.Println("正在修改表" + tableFromCode.TableName)
|
|
//fmt.Println(columnsFromCode)
|
|
//fmt.Println(columnsFromDb)
|
|
|
|
//fmt.Println(indexsFromCode)
|
|
//fmt.Println(indexsFromDb)
|
|
|
|
if tableFromCode.Engine != tableFromDb.Engine {
|
|
sql := "ALTER TABLE " + tableFromCode.TableName + " Engine " + tableFromCode.Engine
|
|
_, err := db.Exec(sql)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
} else {
|
|
fmt.Println("修改表:" + sql)
|
|
}
|
|
}
|
|
|
|
if tableFromCode.Comment != tableFromDb.Comment {
|
|
sql := "ALTER TABLE " + tableFromCode.TableName + " Comment " + tableFromCode.Comment
|
|
_, err := db.Exec(sql)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
} else {
|
|
fmt.Println("修改表:" + sql)
|
|
}
|
|
}
|
|
|
|
for i := 0; i < len(columnsFromCode); i++ {
|
|
isFind := 0
|
|
columnCode := columnsFromCode[i]
|
|
|
|
for j := 0; j < len(columnsFromDb); j++ {
|
|
columnDb := columnsFromDb[j]
|
|
if columnCode.ColumnName == columnDb.ColumnName {
|
|
isFind = 1
|
|
|
|
if columnCode.ColumnType != columnDb.ColumnType || columnCode.ColumnComment != columnDb.ColumnComment || columnCode.Extra != columnDb.Extra || columnCode.DefaultVal != columnDb.DefaultVal {
|
|
sql := "ALTER TABLE " + tableFromCode.TableName + " MODIFY " + getColumnStr(columnCode)
|
|
_, err := db.Exec(sql)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
} else {
|
|
fmt.Println("修改属性:" + sql)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if isFind == 0 {
|
|
sql := "ALTER TABLE " + tableFromCode.TableName + " ADD " + getColumnStr(columnCode)
|
|
_, err := db.Exec(sql)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
} else {
|
|
fmt.Println("增加属性:" + sql)
|
|
}
|
|
}
|
|
}
|
|
|
|
for i := 0; i < len(indexsFromCode); i++ {
|
|
isFind := 0
|
|
indexCode := indexsFromCode[i]
|
|
|
|
for j := 0; j < len(indexsFromDb); j++ {
|
|
indexDb := indexsFromDb[j]
|
|
if indexCode.ColumnName == indexDb.ColumnName {
|
|
isFind = 1
|
|
if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique {
|
|
sql := "ALTER TABLE " + tableFromCode.TableName + " MODIFY " + getIndexStr(indexCode)
|
|
_, err := db.Exec(sql)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
} else {
|
|
fmt.Println("修改索引:" + sql)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if isFind == 0 {
|
|
sql := "ALTER TABLE " + tableFromCode.TableName + " ADD " + getIndexStr(indexCode)
|
|
_, err := db.Exec(sql)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
} else {
|
|
fmt.Println("增加索引:" + sql)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 创建表
|
|
func (db *Executor) createTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index) {
|
|
var fieldArr []string
|
|
|
|
for i := 0; i < len(columnsFromCode); i++ {
|
|
column := columnsFromCode[i]
|
|
fieldArr = append(fieldArr, getColumnStr(column))
|
|
}
|
|
|
|
for i := 0; i < len(indexsFromCode); i++ {
|
|
index := indexsFromCode[i]
|
|
fieldArr = append(fieldArr, getIndexStr(index))
|
|
}
|
|
|
|
sqlStr := "CREATE TABLE `" + tableFromCode.TableName + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + getTableInfoFromCode(tableFromCode) + ";"
|
|
|
|
res, err := db.Exec(sqlStr)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
} else {
|
|
fmt.Println(res.RowsAffected())
|
|
}
|
|
}
|
|
|
|
//
|
|
func (db *Executor) getValFromOpinion(key string, def string) string {
|
|
for i := 0; i < len(db.OpinionList); i++ {
|
|
opinionItem := db.OpinionList[i]
|
|
if opinionItem.Key == key {
|
|
def = opinionItem.Val
|
|
}
|
|
}
|
|
return def
|
|
}
|
|
|
|
func getTableInfoFromCode(tableFromCode Table) string {
|
|
return " ENGINE " + tableFromCode.Engine + " COMMENT " + tableFromCode.Comment
|
|
}
|
|
|
|
// 获得某列的结构
|
|
func getColumnFromCode(fieldName string, fieldType string, fieldMap map[string]string) Column {
|
|
var column Column
|
|
|
|
//字段名
|
|
column.ColumnName = fieldName
|
|
//字段数据类型
|
|
column.DataType = getDataType(fieldType, fieldMap)
|
|
//字段数据长度
|
|
maxLength := getMaxLength(column.DataType, fieldMap)
|
|
columnType := column.DataType
|
|
if maxLength > 0 {
|
|
columnType = columnType + "(" + strconv.Itoa(maxLength) + ")"
|
|
}
|
|
column.MaxLength = maxLength
|
|
//字段是否可以为空
|
|
column.IsNullable = getNullAble(fieldMap)
|
|
//字段注释
|
|
column.ColumnComment = getComment(fieldMap)
|
|
//字段类型
|
|
column.ColumnType = columnType
|
|
//扩展信息
|
|
column.Extra = getExtra(fieldMap)
|
|
//默认信息
|
|
column.DefaultVal = getDefaultVal(fieldMap)
|
|
|
|
return column
|
|
}
|
|
|
|
// getTagMap parses an `aorm` struct tag into a map. The tag is a list of
// entries separated by ";"; each entry is either a bare key ("primary") or a
// "key:value" pair ("size:20").
//
// Only the first ":" of an entry separates key from value, so values may
// themselves contain colons. Keys are trimmed of surrounding whitespace and
// entries with an empty key (for example from a trailing ";") are skipped.
// The returned map is never nil.
func getTagMap(fieldTag string) map[string]string {
	fieldMap := make(map[string]string)

	for _, entry := range strings.Split(fieldTag, ";") {
		parts := strings.SplitN(entry, ":", 2)

		key := strings.TrimSpace(parts[0])
		if key == "" {
			continue
		}

		value := ""
		if len(parts) == 2 {
			value = parts[1]
		}
		fieldMap[key] = value
	}

	return fieldMap
}
|
|
|
|
func getColumnStr(column Column) string {
|
|
var strArr []string
|
|
strArr = append(strArr, column.ColumnName)
|
|
strArr = append(strArr, column.DataType+"("+strconv.Itoa(column.MaxLength)+")")
|
|
|
|
if column.DefaultVal != "" {
|
|
strArr = append(strArr, "DEFAULT '"+column.DefaultVal+"'")
|
|
}
|
|
|
|
if column.IsNullable == "NO" {
|
|
strArr = append(strArr, "NOT NULL")
|
|
}
|
|
|
|
if column.ColumnComment != "" {
|
|
strArr = append(strArr, "COMMENT '"+column.ColumnComment+"'")
|
|
}
|
|
|
|
if column.Extra != "" {
|
|
strArr = append(strArr, column.Extra)
|
|
}
|
|
|
|
return strings.Join(strArr, " ")
|
|
}
|
|
|
|
func getIndexStr(index Index) string {
|
|
var strArr []string
|
|
|
|
if "PRIMARY" == index.KeyName {
|
|
strArr = append(strArr, index.KeyName)
|
|
strArr = append(strArr, "KEY")
|
|
strArr = append(strArr, "(`"+index.ColumnName+"`)")
|
|
} else {
|
|
if 0 == index.NonUnique {
|
|
strArr = append(strArr, "Unique")
|
|
strArr = append(strArr, index.KeyName)
|
|
strArr = append(strArr, "(`"+index.ColumnName+"`)")
|
|
} else {
|
|
strArr = append(strArr, "Index")
|
|
strArr = append(strArr, index.KeyName)
|
|
strArr = append(strArr, "(`"+index.ColumnName+"`)")
|
|
}
|
|
}
|
|
|
|
return strings.Join(strArr, " ")
|
|
}
|
|
|
|
// getDataType maps a struct field's type name to a database data type. An
// explicit "type" entry in the tag map always wins; otherwise the field type
// names Int, String, Bool, Time and Float map to int, varchar, tinyint,
// datetime and float. Any other field type yields the empty string.
func getDataType(fieldType string, fieldMap map[string]string) string {
	if explicit, ok := fieldMap["type"]; ok {
		return explicit
	}

	switch fieldType {
	case "Int":
		return "int"
	case "String":
		return "varchar"
	case "Bool":
		return "tinyint"
	case "Time":
		return "datetime"
	case "Float":
		return "float"
	}
	return ""
}
|
|
|
|
// getMaxLength returns the length to declare for a column of the given data
// type. A "size" entry in the tag map always wins (an unparsable size counts
// as 0); otherwise bigint, int, tinyint and varchar default to 20, 11, 4 and
// 255, and every other type has no length (0).
func getMaxLength(DataType string, fieldMap map[string]string) int {
	if raw, ok := fieldMap["size"]; ok {
		size, _ := strconv.Atoi(raw)
		return size
	}

	switch DataType {
	case "bigint":
		return 20
	case "int":
		return 11
	case "tinyint":
		return 4
	case "varchar":
		return 255
	}
	return 0
}
|
|
|
|
// getNullAble reports the nullability of a column: "NO" when the tag map
// contains the "not null" key, "YES" otherwise.
func getNullAble(fieldMap map[string]string) string {
	if _, notNull := fieldMap["not null"]; notNull {
		return "NO"
	}
	return "YES"
}
|
|
|
|
// getComment returns the value of the "comment" tag entry, or the empty string
// when the entry is absent.
func getComment(fieldMap map[string]string) string {
	// A missing key reads as the zero value, which is the empty string.
	return fieldMap["comment"]
}
|
|
|
|
// getExtra returns "auto_increment" when the tag map contains the
// "auto_increment" key and the empty string otherwise.
func getExtra(fieldMap map[string]string) string {
	const autoIncrement = "auto_increment"

	if _, present := fieldMap[autoIncrement]; !present {
		return ""
	}
	return autoIncrement
}
|
|
|
|
// getDefaultVal returns the value of the "default" tag entry, or the empty
// string when the entry is absent.
func getDefaultVal(fieldMap map[string]string) string {
	// A missing key reads as the zero value, which is the empty string.
	return fieldMap["default"]
}
|