support mssql

This commit is contained in:
tangpanqing
2022-12-25 16:39:59 +08:00
parent 3d7ad79980
commit a88fc51a4a
5 changed files with 135 additions and 61 deletions

BIN
README.md

Binary file not shown.

Binary file not shown.

View File

@@ -65,9 +65,9 @@ func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type)
if len(tablesFromDb) != 0 { if len(tablesFromDb) != 0 {
tableFromDb := tablesFromDb[0] tableFromDb := tablesFromDb[0]
columnsFromDb := mm.getColumnsFromDb(dbName, tableName) columnsFromDb := mm.getColumnsFromDb(dbName, tableName)
indexsFromDb := mm.getIndexesFromDb(tableName) indexesFromDb := mm.getIndexesFromDb(tableName)
mm.modifyTable(tableFromCode, columnsFromCode, indexesFromCode, tableFromDb, columnsFromDb, indexsFromDb) mm.modifyTable(tableFromCode, columnsFromCode, indexesFromCode, tableFromDb, columnsFromDb, indexesFromDb)
} else { } else {
mm.createTable(tableFromCode, columnsFromCode, indexesFromCode) mm.createTable(tableFromCode, columnsFromCode, indexesFromCode)
} }

View File

@@ -7,10 +7,19 @@ import (
"github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/model"
"github.com/tangpanqing/aorm/null" "github.com/tangpanqing/aorm/null"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
) )
type PgIndexes struct {
Schemaname null.String
Tablename null.String
Indexname null.String
Tablespace null.String
Indexdef null.String
}
type Table struct { type Table struct {
TableName null.String TableName null.String
TableComment null.String TableComment null.String
@@ -61,7 +70,6 @@ func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type)
if dbErr != nil { if dbErr != nil {
return dbErr return dbErr
} }
fmt.Println("dbName:" + dbName)
tablesFromDb := mm.getTableFromDb(dbName, tableName) tablesFromDb := mm.getTableFromDb(dbName, tableName)
if len(tablesFromDb) != 0 { if len(tablesFromDb) != 0 {
@@ -172,8 +180,12 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C
mm.Ex.RawSql(sqlColumn).GetMany(&columnsFromDb) mm.Ex.RawSql(sqlColumn).GetMany(&columnsFromDb)
for j := 0; j < len(columnsFromDb); j++ { for j := 0; j < len(columnsFromDb); j++ {
if columnsFromDb[j].DataType.String == "text" && columnsFromDb[j].MaxLength.Int64 == 65535 { if columnsFromDb[j].DataType.String == "character varying" {
columnsFromDb[j].MaxLength = null.IntFrom(0) columnsFromDb[j].DataType = null.StringFrom("varchar")
}
if columnsFromDb[j].DataType.String == "double precision" {
columnsFromDb[j].DataType = null.StringFrom("float")
} }
} }
@@ -181,18 +193,46 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C
} }
func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index {
sqlIndex := "SHOW INDEXES FROM " + tableName sqlIndex := "select * from pg_indexes where tablename=" + "'" + tableName + "'"
var sqliteMasterList []PgIndexes
mm.Ex.RawSql(sqlIndex).GetMany(&sqliteMasterList)
var indexsFromDb []Index var indexesFromDb []Index
mm.Ex.RawSql(sqlIndex).GetMany(&indexsFromDb) for i := 0; i < len(sqliteMasterList); i++ {
indexName := sqliteMasterList[i].Indexname.String
sql := sqliteMasterList[i].Indexdef.String
return indexsFromDb t := 1
if strings.Index(sql, "UNIQUE") != -1 {
t = 0
}
compileRegex := regexp.MustCompile("INDEX\\s(.*?)\\sON.*?\\((.*?)\\)")
matchArr := compileRegex.FindAllStringSubmatch(sql, -1)
//主键索引
if indexName == tableName+"_pkey" {
indexesFromDb = append(indexesFromDb, Index{
NonUnique: null.IntFrom(int64(t)),
ColumnName: null.StringFrom(matchArr[0][2]),
KeyName: null.StringFrom("PRIMARY"),
})
} else {
indexesFromDb = append(indexesFromDb, Index{
NonUnique: null.IntFrom(int64(t)),
ColumnName: null.StringFrom(matchArr[0][2]),
KeyName: null.StringFrom(matchArr[0][1]),
})
}
}
return indexesFromDb
} }
func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexesFromDb []Index) { func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexesFromDb []Index) {
if tableFromCode.TableComment != tableFromDb.TableComment { //if tableFromCode.TableComment != tableFromDb.TableComment {
mm.modifyTableComment(tableFromCode) // mm.modifyTableComment(tableFromCode)
} //}
for i := 0; i < len(columnsFromCode); i++ { for i := 0; i < len(columnsFromCode); i++ {
isFind := 0 isFind := 0
@@ -200,14 +240,14 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co
for j := 0; j < len(columnsFromDb); j++ { for j := 0; j < len(columnsFromDb); j++ {
columnDb := columnsFromDb[j] columnDb := columnsFromDb[j]
if columnCode.ColumnName == columnDb.ColumnName { if columnCode.ColumnName.String == columnDb.ColumnName.String {
isFind = 1 isFind = 1
if columnCode.DataType.String != columnDb.DataType.String || if columnCode.DataType.String != columnDb.DataType.String {
columnCode.MaxLength.Int64 != columnDb.MaxLength.Int64 || fmt.Println(columnCode.ColumnName.String, columnCode.DataType.String, columnDb.DataType.String)
columnCode.ColumnComment.String != columnDb.ColumnComment.String ||
columnCode.Extra.String != columnDb.Extra.String || sql := "ALTER TABLE " + tableFromCode.TableName.String + " alter COLUMN " + getColumnStr(columnCode, "type")
columnCode.ColumnDefault.String != columnDb.ColumnDefault.String { //fmt.Println(sql)
sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode)
_, err := mm.Ex.Exec(sql) _, err := mm.Ex.Exec(sql)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -219,7 +259,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co
} }
if isFind == 0 { if isFind == 0 {
sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode, "")
_, err := mm.Ex.Exec(sql) _, err := mm.Ex.Exec(sql)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -250,13 +290,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co
} }
if isFind == 0 { if isFind == 0 {
sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode) mm.createIndex(tableFromCode.TableName.String, indexCode)
_, err := mm.Ex.Exec(sql)
if err != nil {
fmt.Println(err)
} else {
fmt.Println("增加索引:" + sql)
}
} }
} }
} }
@@ -276,15 +310,17 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co
for i := 0; i < len(columnsFromCode); i++ { for i := 0; i < len(columnsFromCode); i++ {
column := columnsFromCode[i] column := columnsFromCode[i]
fieldArr = append(fieldArr, getColumnStr(column)) fieldArr = append(fieldArr, getColumnStr(column, ""))
} }
for i := 0; i < len(indexesFromCode); i++ { for i := 0; i < len(indexesFromCode); i++ {
index := indexesFromCode[i] index := indexesFromCode[i]
fieldArr = append(fieldArr, getIndexStr(index)) if index.KeyName.String == "PRIMARY" {
fieldArr = append(fieldArr, "PRIMARY KEY ("+index.ColumnName.String+")")
}
} }
sqlStr := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + " COMMENT " + tableFromCode.TableComment.String + ";" sqlStr := "CREATE TABLE " + tableFromCode.TableName.String + " (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";"
fmt.Println(sqlStr) fmt.Println(sqlStr)
_, err := mm.Ex.Exec(sqlStr) _, err := mm.Ex.Exec(sqlStr)
@@ -293,6 +329,29 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co
} else { } else {
fmt.Println("创建表:" + tableFromCode.TableName.String) fmt.Println("创建表:" + tableFromCode.TableName.String)
} }
//创建其他索引
for i := 0; i < len(indexesFromCode); i++ {
index := indexesFromCode[i]
if index.KeyName.String != "PRIMARY" {
mm.createIndex(tableFromCode.TableName.String, index)
}
}
}
func (mm *MigrateExecutor) createIndex(tableName string, index Index) {
keyType := ""
if index.NonUnique.Int64 == 0 {
keyType = "UNIQUE"
}
sql := "CREATE " + keyType + " INDEX " + index.KeyName.String + " on " + tableName + " (" + index.ColumnName.String + ")"
_, err := mm.Ex.Exec(sql)
if err != nil {
fmt.Println(err)
} else {
fmt.Println("增加索引:" + sql)
}
} }
func (mm *MigrateExecutor) getOpinionVal(key string, def string) string { func (mm *MigrateExecutor) getOpinionVal(key string, def string) string {
@@ -321,17 +380,23 @@ func getTagMap(fieldTag string) map[string]string {
return fieldMap return fieldMap
} }
func getColumnStr(column Column) string { func getColumnStr(column Column, f string) string {
var strArr []string var strArr []string
strArr = append(strArr, column.ColumnName.String) strArr = append(strArr, column.ColumnName.String)
if column.MaxLength.Int64 == 0 {
if column.DataType.String == "varchar" { //类型
strArr = append(strArr, column.DataType.String+"(255)") if column.Extra.String == "auto_increment" {
} else { strArr = append(strArr, "serial")
strArr = append(strArr, column.DataType.String)
}
} else { } else {
strArr = append(strArr, column.DataType.String+"("+strconv.Itoa(int(column.MaxLength.Int64))+")") if column.MaxLength.Int64 == 0 {
if column.DataType.String == "varchar" {
strArr = append(strArr, column.DataType.String+"(255)")
} else {
strArr = append(strArr, f+" "+column.DataType.String)
}
} else {
strArr = append(strArr, column.DataType.String+"("+strconv.Itoa(int(column.MaxLength.Int64))+")")
}
} }
if column.ColumnDefault.String != "" { if column.ColumnDefault.String != "" {
@@ -339,15 +404,15 @@ func getColumnStr(column Column) string {
} }
if column.IsNullable.String == "NO" { if column.IsNullable.String == "NO" {
strArr = append(strArr, "NOT NULL") //strArr = append(strArr, "NOT NULL")
} }
if column.ColumnComment.String != "" { if column.ColumnComment.String != "" {
strArr = append(strArr, "COMMENT '"+column.ColumnComment.String+"'") //strArr = append(strArr, "COMMENT '"+column.ColumnComment.String+"'")
} }
if column.Extra.String != "" { if column.Extra.String != "" {
strArr = append(strArr, column.Extra.String) //strArr = append(strArr, column.Extra.String)
} }
return strings.Join(strArr, " ") return strings.Join(strArr, " ")
@@ -359,16 +424,16 @@ func getIndexStr(index Index) string {
if "PRIMARY" == index.KeyName.String { if "PRIMARY" == index.KeyName.String {
strArr = append(strArr, index.KeyName.String) strArr = append(strArr, index.KeyName.String)
strArr = append(strArr, "KEY") strArr = append(strArr, "KEY")
strArr = append(strArr, "(`"+index.ColumnName.String+"`)") strArr = append(strArr, "("+index.ColumnName.String+")")
} else { } else {
if 0 == index.NonUnique.Int64 { if 0 == index.NonUnique.Int64 {
strArr = append(strArr, "Unique") strArr = append(strArr, "Unique")
strArr = append(strArr, index.KeyName.String) strArr = append(strArr, index.KeyName.String)
strArr = append(strArr, "(`"+index.ColumnName.String+"`)") strArr = append(strArr, "("+index.ColumnName.String+")")
} else { } else {
strArr = append(strArr, "Index") strArr = append(strArr, "Index")
strArr = append(strArr, index.KeyName.String) strArr = append(strArr, index.KeyName.String)
strArr = append(strArr, "(`"+index.ColumnName.String+"`)") strArr = append(strArr, "("+index.ColumnName.String+")")
} }
} }
@@ -381,18 +446,25 @@ func getDataType(fieldType string, fieldMap map[string]string) string {
dataTypeVal, dataTypeOk := fieldMap["type"] dataTypeVal, dataTypeOk := fieldMap["type"]
if dataTypeOk { if dataTypeOk {
DataType = dataTypeVal DataType = dataTypeVal
if "tinyint" == DataType {
DataType = "integer"
}
if "double" == DataType {
DataType = "float"
}
} else { } else {
if "Int" == fieldType { if "Int" == fieldType {
DataType = "int" DataType = "integer"
} }
if "String" == fieldType { if "String" == fieldType {
DataType = "varchar" DataType = "varchar"
} }
if "Bool" == fieldType { if "Bool" == fieldType {
DataType = "tinyint" //DataType = "tinyint"
DataType = "boolean"
} }
if "Time" == fieldType { if "Time" == fieldType {
DataType = "datetime" DataType = "date"
} }
if "Float" == fieldType { if "Float" == fieldType {
DataType = "float" DataType = "float"

View File

@@ -12,7 +12,6 @@ import (
"github.com/tangpanqing/aorm/helper" "github.com/tangpanqing/aorm/helper"
"github.com/tangpanqing/aorm/null" "github.com/tangpanqing/aorm/null"
"testing" "testing"
"time"
) )
type Article struct { type Article struct {
@@ -36,7 +35,7 @@ type Person struct {
Sex null.Bool `aorm:"index;comment:性别" json:"sex"` Sex null.Bool `aorm:"index;comment:性别" json:"sex"`
Age null.Int `aorm:"index;comment:年龄" json:"age"` Age null.Int `aorm:"index;comment:年龄" json:"age"`
Type null.Int `aorm:"index;comment:类型" json:"type"` Type null.Int `aorm:"index;comment:类型" json:"type"`
CreateTime null.Time `aorm:"comment:创建时间" json:"createTime"` CreateTime null.Int `aorm:"comment:创建时间" json:"createTime"`
Money null.Float `aorm:"comment:金额" json:"money"` Money null.Float `aorm:"comment:金额" json:"money"`
Test null.Float `aorm:"type:double;comment:测试" json:"test"` Test null.Float `aorm:"type:double;comment:测试" json:"test"`
} }
@@ -52,7 +51,7 @@ type PersonWithArticleCount struct {
Sex null.Bool `aorm:"index;comment:性别" json:"sex"` Sex null.Bool `aorm:"index;comment:性别" json:"sex"`
Age null.Int `aorm:"index;comment:年龄" json:"age"` Age null.Int `aorm:"index;comment:年龄" json:"age"`
Type null.Int `aorm:"index;comment:类型" json:"type"` Type null.Int `aorm:"index;comment:类型" json:"type"`
CreateTime null.Time `aorm:"comment:创建时间" json:"createTime"` CreateTime null.Int `aorm:"comment:创建时间" json:"createTime"`
Money null.Float `aorm:"comment:金额" json:"money"` Money null.Float `aorm:"comment:金额" json:"money"`
Test null.Float `aorm:"type:double;comment:测试" json:"test"` Test null.Float `aorm:"type:double;comment:测试" json:"test"`
ArticleCount null.Int `aorm:"comment:文章数量" json:"articleCount"` ArticleCount null.Int `aorm:"comment:文章数量" json:"articleCount"`
@@ -60,10 +59,10 @@ type PersonWithArticleCount struct {
func TestAll(t *testing.T) { func TestAll(t *testing.T) {
dbList := make([]aorm.DbContent, 0) dbList := make([]aorm.DbContent, 0)
dbList = append(dbList, testSqlite3Connect()) //dbList = append(dbList, testSqlite3Connect())
dbList = append(dbList, testMysqlConnect()) //dbList = append(dbList, testMysqlConnect())
//dbList = append(dbList, testPostgresConnect()) dbList = append(dbList, testPostgresConnect())
dbList = append(dbList, testMssqlConnect()) //dbList = append(dbList, testMssqlConnect())
for i := 0; i < len(dbList); i++ { for i := 0; i < len(dbList); i++ {
dbItem := dbList[i] dbItem := dbList[i]
@@ -145,7 +144,10 @@ func testPostgresConnect() aorm.DbContent {
panic(postgresErr) panic(postgresErr)
} }
postgresContent.DbLink.Ping() err := postgresContent.DbLink.Ping()
if err != nil {
panic(err)
}
return postgresContent return postgresContent
} }
@@ -184,12 +186,12 @@ func testInsert(name string, db *sql.DB) int64 {
Sex: null.BoolFrom(false), Sex: null.BoolFrom(false),
Age: null.IntFrom(18), Age: null.IntFrom(18),
Type: null.IntFrom(0), Type: null.IntFrom(0),
CreateTime: null.TimeFrom(time.Now()), CreateTime: null.IntFrom(2),
Money: null.FloatFrom(100.15), Money: null.FloatFrom(1),
Test: null.FloatFrom(200.15987654321987654321), Test: null.FloatFrom(2),
} }
id, errInsert := aorm.Use(db).Debug(false).Driver(name).Insert(&obj) id, errInsert := aorm.Use(db).Debug(true).Driver(name).Insert(&obj)
if errInsert != nil { if errInsert != nil {
panic(name + " testInsert " + "found err: " + errInsert.Error()) panic(name + " testInsert " + "found err: " + errInsert.Error())
} }
@@ -239,7 +241,7 @@ func testInsertBatch(name string, db *sql.DB) int64 {
Sex: null.BoolFrom(false), Sex: null.BoolFrom(false),
Age: null.IntFrom(18), Age: null.IntFrom(18),
Type: null.IntFrom(0), Type: null.IntFrom(0),
CreateTime: null.TimeFrom(time.Now()), CreateTime: null.IntFrom(1111),
Money: null.FloatFrom(100.15), Money: null.FloatFrom(100.15),
Test: null.FloatFrom(200.15987654321987654321), Test: null.FloatFrom(200.15987654321987654321),
}) })
@@ -249,7 +251,7 @@ func testInsertBatch(name string, db *sql.DB) int64 {
Sex: null.BoolFrom(true), Sex: null.BoolFrom(true),
Age: null.IntFrom(18), Age: null.IntFrom(18),
Type: null.IntFrom(0), Type: null.IntFrom(0),
CreateTime: null.TimeFrom(time.Now()), CreateTime: null.IntFrom(1111),
Money: null.FloatFrom(100.15), Money: null.FloatFrom(100.15),
Test: null.FloatFrom(200.15987654321987654321), Test: null.FloatFrom(200.15987654321987654321),
}) })