Files
aorm/migrate.go
tangpanqing 64765048e4 init
2022-11-29 16:36:06 +08:00

528 lines
13 KiB
Go

package aorm
import (
"fmt"
"reflect"
"strconv"
"strings"
)
// Table describes the table-level options used when creating or altering a
// table: either the options declared in code (getTableFromCode) or the ones
// read back from INFORMATION_SCHEMA.TABLES (getTableFromDb).
type Table struct {
	// TableName is the name of the table.
	TableName string
	// Engine is the storage engine name, e.g. MyISAM.
	Engine string
	// Comment is the table comment. When it comes from Opinion or from the
	// database it already includes the surrounding single quotes.
	Comment string
}
// Column describes one table column, either as declared in code
// (getColumnFromCode) or as read from INFORMATION_SCHEMA.COLUMNS
// (getColumnsFromDb).
type Column struct {
	// ColumnName is the column name.
	ColumnName string
	// ColumnDefault is not set by the migration code in this file; the
	// default value is carried in DefaultVal instead.
	ColumnDefault string
	// IsNullable is "YES" or "NO".
	IsNullable string
	// DataType is the data type without a length, e.g. varchar, bigint, int.
	DataType string
	// MaxLength is the maximum data length, e.g. 20; 0 means no length.
	MaxLength int
	// ColumnType is the column type including the length, e.g. varchar(20).
	ColumnType string
	// ColumnComment is the column comment, without quotes.
	ColumnComment string
	// Extra holds extra attributes, e.g. auto_increment.
	Extra string
	// DefaultVal is the default value; "" means no default.
	DefaultVal string
}
// Index describes one index entry on a single column, either declared in code
// (getIndexsFromCode) or read from SHOW INDEXES (getIndexsFromDb).
type Index struct {
	// NonUnique is 0 for primary and unique indexes and 1 for plain indexes.
	NonUnique int
	// ColumnName is the indexed column.
	ColumnName string
	// KeyName is the index name; "PRIMARY" denotes the primary key.
	KeyName string
}
// OpinionItem is one table option registered with Executor.Opinion.
type OpinionItem struct {
	// Key is the option name; this file reads "ENGINE" and "COMMENT".
	Key string
	// Val is the option value; COMMENT values are stored already quoted.
	Val string
}
// Opinion registers a table option, e.g. Opinion("ENGINE", "InnoDB") or
// Opinion("COMMENT", "users"). The options are appended to db.OpinionList and
// later read by migrations through getValFromOpinion; this file reads the
// keys "ENGINE" and "COMMENT". A COMMENT value is wrapped in single quotes so
// it can be spliced into SQL and compared with the quoted TABLE_COMMENT read
// back from the database. It returns db so calls can be chained.
func (db *Executor) Opinion(key string, val string) *Executor {
	item := OpinionItem{Key: key, Val: val}
	if key == "COMMENT" {
		item.Val = "'" + val + "'"
	}
	db.OpinionList = append(db.OpinionList, item)
	return db
}
// ShowCreateTable returns the "Create Table" column of
// `show create table <tableName>`.
//
// If the query fails, the error is printed (like the rest of this file does)
// and "" is returned. "" is also returned when no row comes back or the column
// is not a string, instead of panicking on the missing row.
func (db *Executor) ShowCreateTable(tableName string) string {
	list, err := db.Query("show create table " + tableName)
	if err != nil {
		fmt.Println(err)
		return ""
	}
	if len(list) == 0 {
		return ""
	}
	createSQL, _ := list[0]["Create Table"].(string)
	return createSQL
}
// AutoMigrate migrates the table structure declared by dest into the database
// the connection currently uses.
//
// The table name is derived from dest's type name: the package qualifier is
// dropped (everything up to the last '.') and the rest is converted with
// UnderLine. dest is expected to be a pointer to a struct, because the field
// walk calls typeOf.Elem().
func (db *Executor) AutoMigrate(dest interface{}) {
	typeOf := reflect.TypeOf(dest)
	typeName := typeOf.String()
	// typeOf.String() is package-qualified (e.g. "*pkg.Person"); keep only the
	// segment after the last dot.
	if dot := strings.LastIndex(typeName, "."); dot >= 0 {
		typeName = typeName[dot+1:]
	}
	db.migrateCommon(UnderLine(typeName), typeOf)
}
// Migrate migrates the table structure declared by dest into the table named
// tableName, in the database the connection currently uses.
// dest is expected to be a pointer to a struct.
func (db *Executor) Migrate(tableName string, dest interface{}) {
	db.migrateCommon(tableName, reflect.TypeOf(dest))
}
// migrateCommon compares the structure declared by typeOf (a pointer-to-struct
// type) with the table tableName in the currently selected database. It alters
// the table when it exists and creates it otherwise.
//
// Lookup failures are printed and abort the migration. A failed lookup must
// not be mistaken for "table does not exist", otherwise CREATE TABLE would be
// attempted for a table that may well exist.
func (db *Executor) migrateCommon(tableName string, typeOf reflect.Type) {
	tableFromCode := db.getTableFromCode(tableName)
	columnsFromCode := db.getColumnsFromCode(typeOf)
	indexsFromCode := db.getIndexsFromCode(typeOf, tableFromCode)

	// Resolve the database name. A missing row or a non-string value (e.g.
	// NULL when no database is selected) aborts instead of panicking.
	dbNameRows, err := db.Query("SELECT DATABASE()")
	if err != nil {
		fmt.Println(err)
		return
	}
	dbName := ""
	if len(dbNameRows) > 0 {
		dbName, _ = dbNameRows[0]["DATABASE()"].(string)
	}
	if dbName == "" {
		fmt.Println("migrate " + tableName + ": no database selected")
		return
	}

	// Look up the table: alter it when found, create it when absent.
	sql := "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
	dataList, err := db.Query(sql)
	if err != nil {
		fmt.Println(err)
		return
	}
	if len(dataList) == 0 {
		db.createTable(tableFromCode, columnsFromCode, indexsFromCode)
		return
	}

	tableFromDb := getTableFromDb(dataList)
	columnsFromDb := db.getColumnsFromDb(dbName, tableName)
	indexsFromDb := db.getIndexsFromDb(tableName)
	db.modifyTable(tableFromCode, columnsFromCode, indexsFromCode, tableFromDb, columnsFromDb, indexsFromDb)
}
// getTableFromCode builds the desired table options for tableName from the
// options registered via Opinion.
//
// ENGINE defaults to MyISAM. COMMENT defaults to the quoted empty literal ''
// rather than "", for two reasons:
//   - the generated "CREATE TABLE ... COMMENT <value>" clause stays valid SQL;
//   - the value equals the quoted TABLE_COMMENT that getTableFromDb reports
//     for a table without a comment, so no spurious ALTER is issued.
func (db *Executor) getTableFromCode(tableName string) Table {
	return Table{
		TableName: tableName,
		Engine:    db.getValFromOpinion("ENGINE", "MyISAM"),
		Comment:   db.getValFromOpinion("COMMENT", "''"),
	}
}
// getColumnsFromCode converts every field of the struct that typeOf points to
// into a Column, in declaration order. The column name is the field name
// passed through UnderLine, the Go type name drives type inference, and the
// remaining attributes come from the field's `aorm` tag.
func (db *Executor) getColumnsFromCode(typeOf reflect.Type) []Column {
	structType := typeOf.Elem()
	var columns []Column
	for idx := 0; idx < structType.NumField(); idx++ {
		field := structType.Field(idx)
		columns = append(columns, getColumnFromCode(
			UnderLine(field.Name),
			field.Type.Name(),
			getTagMap(field.Tag.Get("aorm")),
		))
	}
	return columns
}
// getIndexsFromCode collects the indexes declared through `aorm` tags on the
// struct that typeOf points to:
//
//	primary -> the PRIMARY key (NonUnique 0)
//	unique  -> idx_<table>_<column>, NonUnique 0
//	index   -> idx_<table>_<column>, NonUnique 1
//
// A field may carry several of these flags; each produces its own entry.
func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table) []Index {
	structType := typeOf.Elem()
	var indexes []Index
	for idx := 0; idx < structType.NumField(); idx++ {
		field := structType.Field(idx)
		column := UnderLine(field.Name)
		tags := getTagMap(field.Tag.Get("aorm"))
		keyName := "idx_" + tableFromCode.TableName + "_" + column

		if _, ok := tags["primary"]; ok {
			indexes = append(indexes, Index{NonUnique: 0, ColumnName: column, KeyName: "PRIMARY"})
		}
		if _, ok := tags["unique"]; ok {
			indexes = append(indexes, Index{NonUnique: 0, ColumnName: column, KeyName: keyName})
		}
		if _, ok := tags["index"]; ok {
			indexes = append(indexes, Index{NonUnique: 1, ColumnName: column, KeyName: keyName})
		}
	}
	return indexes
}
// getTableFromDb reads the table options from the first row of an
// INFORMATION_SCHEMA.TABLES result. The comment is wrapped in single quotes
// so it compares directly with a COMMENT value registered through Opinion.
// Values are stringified with %v, so a nil value becomes "<nil>".
func getTableFromDb(dataList []map[string]interface{}) Table {
	row := dataList[0]
	return Table{
		TableName: fmt.Sprintf("%v", row["TABLE_NAME"]),
		Engine:    fmt.Sprintf("%v", row["ENGINE"]),
		Comment:   fmt.Sprintf("'%v'", row["TABLE_COMMENT"]),
	}
}
// getColumnsFromDb loads the current column definitions of dbName.tableName
// from INFORMATION_SCHEMA.COLUMNS, one Column per returned row.
//
// The Query error is discarded. A non-numeric CHARACTER_MAXIMUM_LENGTH (e.g.
// nil) yields MaxLength 0, and a nil COLUMN_DEFAULT yields DefaultVal "".
// The remaining fields use .(string) assertions and so assume the driver
// returns string values for them.
func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column {
	query := "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
	rows, _ := db.Query(query)

	var columns []Column
	for _, row := range rows {
		maxLength, _ := strconv.Atoi(fmt.Sprintf("%v", row["CHARACTER_MAXIMUM_LENGTH"]))

		defaultVal := ""
		if raw := row["COLUMN_DEFAULT"]; raw != nil {
			defaultVal = raw.(string)
		}

		columns = append(columns, Column{
			ColumnName:    row["COLUMN_NAME"].(string),
			DataType:      row["DATA_TYPE"].(string),
			IsNullable:    row["IS_NULLABLE"].(string),
			MaxLength:     maxLength,
			ColumnType:    row["COLUMN_TYPE"].(string),
			ColumnComment: row["COLUMN_COMMENT"].(string),
			Extra:         row["EXTRA"].(string),
			DefaultVal:    defaultVal,
		})
	}
	return columns
}
// getIndexsFromDb lists the existing indexes of tableName via SHOW INDEXES,
// producing one Index per returned row. The Query error is discarded, and a
// non-numeric Non_unique value yields NonUnique 0.
func (db *Executor) getIndexsFromDb(tableName string) []Index {
	rows, _ := db.Query("SHOW INDEXES FROM " + tableName)

	var indexes []Index
	for _, row := range rows {
		nonUnique, _ := strconv.Atoi(fmt.Sprintf("%v", row["Non_unique"]))
		indexes = append(indexes, Index{
			NonUnique:  nonUnique,
			ColumnName: fmt.Sprintf("%v", row["Column_name"]),
			KeyName:    fmt.Sprintf("%v", row["Key_name"]),
		})
	}
	return indexes
}
// modifyTable brings an existing table in line with the structure declared in
// code. It checks, in this order:
//   - the storage engine and the table comment;
//   - every code column against the database column with the same name
//     (MODIFY when type, comment, extra or default differ; ADD when missing);
//   - every code index against the database indexes with the same key name on
//     the same column (drop-and-re-add when uniqueness differs; ADD when
//     missing).
//
// Each statement runs on its own. A failure is printed and the remaining
// statements are still attempted. Columns and indexes that exist only in the
// database are left untouched.
func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexsFromDb []Index) {
	// Quote the table name the same way createTable does.
	quotedTable := "`" + tableFromCode.TableName + "`"

	// run executes one ALTER statement; it prints the error, or the statement
	// prefixed with label on success.
	run := func(sql string, label string) {
		if _, err := db.Exec(sql); err != nil {
			fmt.Println(err)
			return
		}
		fmt.Println(label + sql)
	}

	// MySQL engine names are case-insensitive ("innodb" == "InnoDB").
	if !strings.EqualFold(tableFromCode.Engine, tableFromDb.Engine) {
		run("ALTER TABLE "+quotedTable+" Engine "+tableFromCode.Engine, "修改表:")
	}
	// An empty comment would render as "... Comment " (invalid SQL), so only
	// a comment that is actually set is applied.
	if tableFromCode.Comment != "" && tableFromCode.Comment != tableFromDb.Comment {
		run("ALTER TABLE "+quotedTable+" Comment "+tableFromCode.Comment, "修改表:")
	}

	for _, columnCode := range columnsFromCode {
		found := false
		for _, columnDb := range columnsFromDb {
			if columnCode.ColumnName != columnDb.ColumnName {
				continue
			}
			found = true
			if columnCode.ColumnType != columnDb.ColumnType || columnCode.ColumnComment != columnDb.ColumnComment || columnCode.Extra != columnDb.Extra || columnCode.DefaultVal != columnDb.DefaultVal {
				run("ALTER TABLE "+quotedTable+" MODIFY "+getColumnStr(columnCode), "修改属性:")
			}
			break
		}
		if !found {
			run("ALTER TABLE "+quotedTable+" ADD "+getColumnStr(columnCode), "增加属性:")
		}
	}

	for _, indexCode := range indexsFromCode {
		exists, same := false, false
		for _, indexDb := range indexsFromDb {
			if indexDb.ColumnName == indexCode.ColumnName && indexDb.KeyName == indexCode.KeyName {
				exists = true
				same = indexDb.NonUnique == indexCode.NonUnique
				break
			}
		}
		switch {
		case !exists:
			run("ALTER TABLE "+quotedTable+" ADD "+getIndexStr(indexCode), "增加索引:")
		case !same:
			// MySQL cannot MODIFY an index; drop and re-add it in a single
			// ALTER so the table is never left without it on failure.
			drop := "DROP INDEX `" + indexCode.KeyName + "`"
			if indexCode.KeyName == "PRIMARY" {
				drop = "DROP PRIMARY KEY"
			}
			run("ALTER TABLE "+quotedTable+" "+drop+", ADD "+getIndexStr(indexCode), "修改索引:")
		}
	}
}
// createTable issues one CREATE TABLE statement containing every column and
// index declared in code, followed by the table options (engine, comment).
//
// Errors are printed. On success the executed statement is printed, matching
// the ALTER statements logged by modifyTable; the affected-row count of a DDL
// statement carries no information.
func (db *Executor) createTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index) {
	definitions := make([]string, 0, len(columnsFromCode)+len(indexsFromCode))
	for _, column := range columnsFromCode {
		definitions = append(definitions, getColumnStr(column))
	}
	for _, index := range indexsFromCode {
		definitions = append(definitions, getIndexStr(index))
	}

	sqlStr := "CREATE TABLE `" + tableFromCode.TableName + "` (\n" + strings.Join(definitions, ",\n") + "\n) " + getTableInfoFromCode(tableFromCode) + ";"
	if _, err := db.Exec(sqlStr); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("创建表:" + sqlStr)
}
// getValFromOpinion returns the value of the most recently registered option
// named key, or def when no such option was registered.
// NOTE(review): nothing visible here clears db.OpinionList, so options may
// carry over between migrations on the same Executor — confirm with callers.
func (db *Executor) getValFromOpinion(key string, def string) string {
	for idx := len(db.OpinionList) - 1; idx >= 0; idx-- {
		if db.OpinionList[idx].Key == key {
			return db.OpinionList[idx].Val
		}
	}
	return def
}
// getTableInfoFromCode renders the table options that follow the column list
// of a CREATE TABLE statement, e.g. " ENGINE InnoDB COMMENT 'users'".
//
// Empty options are omitted, because a bare ENGINE or COMMENT keyword without
// a value is a syntax error. Comment must already carry its single quotes, as
// produced by Opinion.
func getTableInfoFromCode(tableFromCode Table) string {
	info := ""
	if tableFromCode.Engine != "" {
		info += " ENGINE " + tableFromCode.Engine
	}
	if tableFromCode.Comment != "" {
		info += " COMMENT " + tableFromCode.Comment
	}
	return info
}
// getColumnFromCode builds the desired definition of one column.
//
// fieldName is the column name, fieldType is the Go type name used to infer
// the data type, and fieldMap is the parsed `aorm` tag. ColumnType is the
// data type plus "(length)" when the resolved length is positive.
func getColumnFromCode(fieldName string, fieldType string, fieldMap map[string]string) Column {
	dataType := getDataType(fieldType, fieldMap)
	maxLength := getMaxLength(dataType, fieldMap)

	columnType := dataType
	if maxLength > 0 {
		columnType += "(" + strconv.Itoa(maxLength) + ")"
	}

	return Column{
		ColumnName:    fieldName,
		DataType:      dataType,
		MaxLength:     maxLength,
		IsNullable:    getNullAble(fieldMap),
		ColumnComment: getComment(fieldMap),
		ColumnType:    columnType,
		Extra:         getExtra(fieldMap),
		DefaultVal:    getDefaultVal(fieldMap),
	}
}
// getTagMap parses an `aorm` struct tag of the form "key1;key2:value2;..."
// into a map. A segment without ':' maps its key to "".
//
// Only the first ':' separates key from value, so values may themselves
// contain ':' (e.g. "default:12:00:00" or "comment:a:b"). Whitespace around a
// key is trimmed, so "primary; unique" is recognized. Empty segments, such as
// one left by a trailing ';', are skipped. The result is never nil.
func getTagMap(fieldTag string) map[string]string {
	fieldMap := make(map[string]string)
	if fieldTag == "" {
		return fieldMap
	}
	for _, segment := range strings.Split(fieldTag, ";") {
		parts := strings.SplitN(segment, ":", 2)
		key := strings.TrimSpace(parts[0])
		if key == "" {
			continue
		}
		val := ""
		if len(parts) == 2 {
			val = parts[1]
		}
		fieldMap[key] = val
	}
	return fieldMap
}
// getColumnStr renders a column definition for CREATE TABLE and for
// ALTER TABLE ADD/MODIFY:
//
//	`name` type[(length)] [DEFAULT '...'] [NOT NULL] [COMMENT '...'] [extra]
//
// The "(length)" suffix is emitted only when MaxLength > 0, the same rule
// getColumnFromCode uses for ColumnType; types such as date or json reject a
// "(0)" suffix. The column name is back-quoted so reserved words can be used
// as column names. Single quotes inside the default value and the comment are
// doubled so they cannot terminate the SQL string literal.
func getColumnStr(column Column) string {
	dataType := column.DataType
	if column.MaxLength > 0 {
		dataType += "(" + strconv.Itoa(column.MaxLength) + ")"
	}
	parts := []string{"`" + column.ColumnName + "`", dataType}
	if column.DefaultVal != "" {
		parts = append(parts, "DEFAULT '"+strings.ReplaceAll(column.DefaultVal, "'", "''")+"'")
	}
	if column.IsNullable == "NO" {
		parts = append(parts, "NOT NULL")
	}
	if column.ColumnComment != "" {
		parts = append(parts, "COMMENT '"+strings.ReplaceAll(column.ColumnComment, "'", "''")+"'")
	}
	if column.Extra != "" {
		parts = append(parts, column.Extra)
	}
	return strings.Join(parts, " ")
}
// getIndexStr renders an index definition usable both inside CREATE TABLE and
// after ALTER TABLE ... ADD:
//
//	PRIMARY KEY (`col`)    when KeyName is "PRIMARY"
//	Unique `key` (`col`)   when NonUnique == 0
//	Index `key` (`col`)    otherwise
//
// The key name is back-quoted like the column name. Generated key names embed
// the table name, which createTable also back-quotes, so names containing
// characters such as '-' still parse.
func getIndexStr(index Index) string {
	columnPart := "(`" + index.ColumnName + "`)"
	switch {
	case index.KeyName == "PRIMARY":
		return "PRIMARY KEY " + columnPart
	case index.NonUnique == 0:
		return "Unique `" + index.KeyName + "` " + columnPart
	default:
		return "Index `" + index.KeyName + "` " + columnPart
	}
}
// getDataType returns the database data type for a struct field.
//
// An explicit `type` tag value wins, even when it is empty. Otherwise the type
// is inferred from the Go type name fieldType. Only the exact, capitalized
// names Int, String, Bool, Time and Float are recognized; any other name
// (including builtin names such as int or string) yields "".
func getDataType(fieldType string, fieldMap map[string]string) string {
	if explicit, ok := fieldMap["type"]; ok {
		return explicit
	}
	switch fieldType {
	case "Int":
		return "int"
	case "String":
		return "varchar"
	case "Bool":
		return "tinyint"
	case "Time":
		return "datetime"
	case "Float":
		return "float"
	}
	return ""
}
// getMaxLength returns the length for a column of type DataType.
//
// An explicit `size` tag wins; a non-numeric size yields 0. Otherwise known
// types get a default: bigint 20, int 11, tinyint 4, varchar 255. Every
// other type (text, datetime, float, unknown) gets 0, meaning no length.
func getMaxLength(DataType string, fieldMap map[string]string) int {
	if sizeVal, ok := fieldMap["size"]; ok {
		size, _ := strconv.Atoi(sizeVal)
		return size
	}
	switch DataType {
	case "bigint":
		return 20
	case "int":
		return 11
	case "tinyint":
		return 4
	case "varchar":
		return 255
	}
	return 0
}
// getNullAble maps the "not null" tag flag to the YES/NO convention of the
// IS_NULLABLE value stored in Column: "NO" when the flag is present, "YES"
// otherwise.
func getNullAble(fieldMap map[string]string) string {
	if _, notNull := fieldMap["not null"]; notNull {
		return "NO"
	}
	return "YES"
}
// getComment returns the `comment` tag value, or "" when the tag is absent.
func getComment(fieldMap map[string]string) string {
	// A missing key yields the zero value "", which also means "no comment".
	return fieldMap["comment"]
}
// getExtra returns "auto_increment" when the field carries the auto_increment
// tag flag, and "" otherwise.
func getExtra(fieldMap map[string]string) string {
	extra := ""
	if _, autoIncrement := fieldMap["auto_increment"]; autoIncrement {
		extra = "auto_increment"
	}
	return extra
}
// getDefaultVal returns the `default` tag value, or "" when the tag is
// absent. "" is treated as "no default" by the SQL rendering code.
func getDefaultVal(fieldMap map[string]string) string {
	// A missing key yields the zero value "".
	return fieldMap["default"]
}