mirror of
https://github.com/tangpanqing/aorm.git
synced 2025-10-11 19:00:11 +08:00
add sub query
This commit is contained in:
11
aorm.go
11
aorm.go
@@ -14,7 +14,10 @@ type LinkCommon interface {
|
|||||||
|
|
||||||
// Executor 查询记录所需要的条件
|
// Executor 查询记录所需要的条件
|
||||||
type Executor struct {
|
type Executor struct {
|
||||||
|
//数据库操作连接
|
||||||
linkCommon LinkCommon
|
linkCommon LinkCommon
|
||||||
|
|
||||||
|
//查询参数
|
||||||
tableName string
|
tableName string
|
||||||
selectList []string
|
selectList []string
|
||||||
selectExpList []*ExpItem
|
selectExpList []*ExpItem
|
||||||
@@ -27,6 +30,12 @@ type Executor struct {
|
|||||||
pageSize int
|
pageSize int
|
||||||
isDebug bool
|
isDebug bool
|
||||||
isLockForUpdate bool
|
isLockForUpdate bool
|
||||||
|
|
||||||
|
//sql与参数
|
||||||
|
sql string
|
||||||
|
paramList []interface{}
|
||||||
|
|
||||||
|
//表属性
|
||||||
opinionList []OpinionItem
|
opinionList []OpinionItem
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,5 +72,7 @@ func (db *Executor) clear() {
|
|||||||
db.pageSize = 0
|
db.pageSize = 0
|
||||||
db.isDebug = false
|
db.isDebug = false
|
||||||
db.isLockForUpdate = false
|
db.isLockForUpdate = false
|
||||||
|
db.sql = ""
|
||||||
|
db.paramList = make([]interface{}, 0)
|
||||||
db.opinionList = make([]OpinionItem, 0)
|
db.opinionList = make([]OpinionItem, 0)
|
||||||
}
|
}
|
||||||
|
139
crud.go
139
crud.go
@@ -134,21 +134,31 @@ func (db *Executor) InsertBatch(values interface{}) (int64, error) {
|
|||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMany 查询记录(新)
|
// GetRows 获取行操作
|
||||||
func (db *Executor) GetMany(values interface{}) error {
|
func (db *Executor) GetRows() (*sql.Rows, error) {
|
||||||
sqlStr, paramList := db.GetSqlAndParams()
|
sqlStr, paramList := db.GetSqlAndParams()
|
||||||
|
|
||||||
smt, errSmt := db.linkCommon.Prepare(sqlStr)
|
smt, errSmt := db.linkCommon.Prepare(sqlStr)
|
||||||
if errSmt != nil {
|
if errSmt != nil {
|
||||||
return errSmt
|
return nil, errSmt
|
||||||
}
|
}
|
||||||
defer smt.Close()
|
//defer smt.Close()
|
||||||
|
|
||||||
rows, errRows := smt.Query(paramList...)
|
rows, errRows := smt.Query(paramList...)
|
||||||
|
if errRows != nil {
|
||||||
|
return nil, errRows
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMany 查询记录(新)
|
||||||
|
func (db *Executor) GetMany(values interface{}) error {
|
||||||
|
rows, errRows := db.GetRows()
|
||||||
|
defer rows.Close()
|
||||||
if errRows != nil {
|
if errRows != nil {
|
||||||
return errRows
|
return errRows
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
destSlice := reflect.Indirect(reflect.ValueOf(values))
|
destSlice := reflect.Indirect(reflect.ValueOf(values))
|
||||||
destType := destSlice.Type().Elem()
|
destType := destSlice.Type().Elem()
|
||||||
@@ -181,19 +191,11 @@ func (db *Executor) GetMany(values interface{}) error {
|
|||||||
func (db *Executor) GetOne(obj interface{}) error {
|
func (db *Executor) GetOne(obj interface{}) error {
|
||||||
db.Limit(0, 1)
|
db.Limit(0, 1)
|
||||||
|
|
||||||
sqlStr, paramList := db.GetSqlAndParams()
|
rows, errRows := db.GetRows()
|
||||||
|
defer rows.Close()
|
||||||
smt, errSmt := db.linkCommon.Prepare(sqlStr)
|
|
||||||
if errSmt != nil {
|
|
||||||
return errSmt
|
|
||||||
}
|
|
||||||
defer smt.Close()
|
|
||||||
|
|
||||||
rows, errRows := smt.Query(paramList...)
|
|
||||||
if errRows != nil {
|
if errRows != nil {
|
||||||
return errRows
|
return errRows
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
destType := reflect.TypeOf(obj).Elem()
|
destType := reflect.TypeOf(obj).Elem()
|
||||||
destValue := reflect.ValueOf(obj).Elem()
|
destValue := reflect.ValueOf(obj).Elem()
|
||||||
@@ -218,8 +220,19 @@ func (db *Executor) GetOne(obj interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Executor) GetSqlAndParams() (string, []any) {
|
// RawSql 执行原始的sql语句
|
||||||
var paramList []any
|
func (db *Executor) RawSql(sql string, paramList ...interface{}) *Executor {
|
||||||
|
db.sql = sql
|
||||||
|
db.paramList = paramList
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *Executor) GetSqlAndParams() (string, []interface{}) {
|
||||||
|
if db.sql != "" {
|
||||||
|
return db.sql, db.paramList
|
||||||
|
}
|
||||||
|
|
||||||
|
var paramList []interface{}
|
||||||
|
|
||||||
fieldStr, paramList := handleField(db.selectList, db.selectExpList, paramList)
|
fieldStr, paramList := handleField(db.selectList, db.selectExpList, paramList)
|
||||||
whereStr, paramList := handleWhere(db.whereList, paramList)
|
whereStr, paramList := handleWhere(db.whereList, paramList)
|
||||||
@@ -323,22 +336,13 @@ func (db *Executor) Min(fieldName string) (float64, error) {
|
|||||||
|
|
||||||
// Value 字段值
|
// Value 字段值
|
||||||
func (db *Executor) Value(fieldName string, dest interface{}) error {
|
func (db *Executor) Value(fieldName string, dest interface{}) error {
|
||||||
|
|
||||||
db.Select(fieldName).Limit(0, 1)
|
db.Select(fieldName).Limit(0, 1)
|
||||||
|
|
||||||
sqlStr, paramList := db.GetSqlAndParams()
|
rows, errRows := db.GetRows()
|
||||||
|
defer rows.Close()
|
||||||
smt, errSmt := db.linkCommon.Prepare(sqlStr)
|
|
||||||
if errSmt != nil {
|
|
||||||
return errSmt
|
|
||||||
}
|
|
||||||
defer smt.Close()
|
|
||||||
|
|
||||||
rows, errRows := smt.Query(paramList...)
|
|
||||||
if errRows != nil {
|
if errRows != nil {
|
||||||
return errRows
|
return errRows
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
destValue := reflect.ValueOf(dest).Elem()
|
destValue := reflect.ValueOf(dest).Elem()
|
||||||
|
|
||||||
@@ -372,19 +376,11 @@ func (db *Executor) Value(fieldName string, dest interface{}) error {
|
|||||||
func (db *Executor) Pluck(fieldName string, values interface{}) error {
|
func (db *Executor) Pluck(fieldName string, values interface{}) error {
|
||||||
db.Select(fieldName)
|
db.Select(fieldName)
|
||||||
|
|
||||||
sqlStr, paramList := db.GetSqlAndParams()
|
rows, errRows := db.GetRows()
|
||||||
|
defer rows.Close()
|
||||||
smt, errSmt := db.linkCommon.Prepare(sqlStr)
|
|
||||||
if errSmt != nil {
|
|
||||||
return errSmt
|
|
||||||
}
|
|
||||||
defer smt.Close()
|
|
||||||
|
|
||||||
rows, errRows := smt.Query(paramList...)
|
|
||||||
if errRows != nil {
|
if errRows != nil {
|
||||||
return errRows
|
return errRows
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
destSlice := reflect.Indirect(reflect.ValueOf(values))
|
destSlice := reflect.Indirect(reflect.ValueOf(values))
|
||||||
destType := destSlice.Type().Elem()
|
destType := destSlice.Type().Elem()
|
||||||
@@ -438,71 +434,6 @@ func (db *Executor) Decrement(fieldName string, step int) (int64, error) {
|
|||||||
return db.ExecAffected(sqlStr, paramList...)
|
return db.ExecAffected(sqlStr, paramList...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query 通用查询
|
|
||||||
func (db *Executor) Query(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) {
|
|
||||||
if db.isDebug {
|
|
||||||
fmt.Println(sqlStr)
|
|
||||||
fmt.Println(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var listData []map[string]interface{}
|
|
||||||
|
|
||||||
smt, err1 := db.linkCommon.Prepare(sqlStr)
|
|
||||||
if err1 != nil {
|
|
||||||
return listData, err1
|
|
||||||
}
|
|
||||||
defer smt.Close()
|
|
||||||
|
|
||||||
rows, err2 := smt.Query(args...)
|
|
||||||
if err2 != nil {
|
|
||||||
return listData, err2
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
fieldsTypes, errType := rows.ColumnTypes()
|
|
||||||
if errType != nil {
|
|
||||||
return make([]map[string]interface{}, 0), errType
|
|
||||||
}
|
|
||||||
fields, errColumns := rows.Columns()
|
|
||||||
if errColumns != nil {
|
|
||||||
return make([]map[string]interface{}, 0), errColumns
|
|
||||||
}
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
data := make(map[string]interface{})
|
|
||||||
|
|
||||||
scans := make([]interface{}, len(fields))
|
|
||||||
for i := range scans {
|
|
||||||
scans[i] = &scans[i]
|
|
||||||
}
|
|
||||||
err := rows.Scan(scans...)
|
|
||||||
if err != nil {
|
|
||||||
return make([]map[string]interface{}, 0), err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range scans {
|
|
||||||
if v == nil {
|
|
||||||
data[fields[i]] = v
|
|
||||||
} else {
|
|
||||||
if fieldsTypes[i].DatabaseTypeName() == "VARCHAR" || fieldsTypes[i].DatabaseTypeName() == "TEXT" || fieldsTypes[i].DatabaseTypeName() == "CHAR" || fieldsTypes[i].DatabaseTypeName() == "LONGTEXT" {
|
|
||||||
data[fields[i]] = fmt.Sprintf("%s", v)
|
|
||||||
} else if fieldsTypes[i].DatabaseTypeName() == "INT" || fieldsTypes[i].DatabaseTypeName() == "BIGINT" || fieldsTypes[i].DatabaseTypeName() == "UNSIGNED INT" || fieldsTypes[i].DatabaseTypeName() == "UNSIGNED BIGINT" {
|
|
||||||
data[fields[i]] = fmt.Sprintf("%v", v)
|
|
||||||
} else if fieldsTypes[i].DatabaseTypeName() == "DECIMAL" {
|
|
||||||
data[fields[i]] = string(v.([]uint8))
|
|
||||||
} else {
|
|
||||||
data[fields[i]] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
listData = append(listData, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.clear()
|
|
||||||
return listData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec 通用执行-新增,更新,删除
|
// Exec 通用执行-新增,更新,删除
|
||||||
func (db *Executor) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
|
func (db *Executor) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
|
||||||
if db.isDebug {
|
if db.isDebug {
|
||||||
@@ -1177,7 +1108,7 @@ func getFieldNameMap(destValue reflect.Value, destType reflect.Type) map[string]
|
|||||||
func getScans(columnNameList []string, fieldNameMap map[string]int, destValue reflect.Value) []interface{} {
|
func getScans(columnNameList []string, fieldNameMap map[string]int, destValue reflect.Value) []interface{} {
|
||||||
var scans []interface{}
|
var scans []interface{}
|
||||||
for _, columnName := range columnNameList {
|
for _, columnName := range columnNameList {
|
||||||
fieldName := CamelString(columnName)
|
fieldName := CamelString(strings.ToLower(columnName))
|
||||||
index, ok := fieldNameMap[fieldName]
|
index, ok := fieldNameMap[fieldName]
|
||||||
if ok {
|
if ok {
|
||||||
scans = append(scans, destValue.Field(index).Addr().Interface())
|
scans = append(scans, destValue.Field(index).Addr().Interface())
|
||||||
|
176
migrate.go
176
migrate.go
@@ -8,21 +8,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Table struct {
|
type Table struct {
|
||||||
TableName string
|
TableName String
|
||||||
Engine string
|
Engine String
|
||||||
Comment string
|
TableComment String
|
||||||
}
|
}
|
||||||
|
|
||||||
type Column struct {
|
type Column struct {
|
||||||
ColumnName string
|
ColumnName String
|
||||||
ColumnDefault string
|
ColumnDefault String
|
||||||
IsNullable string
|
IsNullable String
|
||||||
DataType string //数据类型 varchar,bigint,int
|
DataType String //数据类型 varchar,bigint,int
|
||||||
MaxLength int //数据最大长度 20
|
MaxLength Int //数据最大长度 20
|
||||||
ColumnType string //列类型 varchar(20)
|
ColumnComment String
|
||||||
ColumnComment string
|
Extra String //扩展信息 auto_increment
|
||||||
Extra string //扩展信息 auto_increment
|
|
||||||
DefaultVal string //默认值
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Index struct {
|
type Index struct {
|
||||||
@@ -47,11 +45,12 @@ func (db *Executor) Opinion(key string, val string) *Executor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *Executor) ShowCreateTable(tableName string) string {
|
func (db *Executor) ShowCreateTable(tableName string) string {
|
||||||
list, _ := db.Query("show create table " + tableName)
|
var str string
|
||||||
return list[0]["Create Table"].(string)
|
db.RawSql("show create table "+tableName).Value("Create Table", &str)
|
||||||
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate 迁移数据库结构,需要输入数据库名,表名自动获取
|
// AutoMigrate 迁移数据库结构,需要输入数据库名,表名自动获取
|
||||||
func (db *Executor) AutoMigrate(dest interface{}) {
|
func (db *Executor) AutoMigrate(dest interface{}) {
|
||||||
typeOf := reflect.TypeOf(dest)
|
typeOf := reflect.TypeOf(dest)
|
||||||
arr := strings.Split(typeOf.String(), ".")
|
arr := strings.Split(typeOf.String(), ".")
|
||||||
@@ -60,7 +59,7 @@ func (db *Executor) AutoMigrate(dest interface{}) {
|
|||||||
db.migrateCommon(tableName, typeOf)
|
db.migrateCommon(tableName, typeOf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoMigrate 自动迁移数据库结构,需要输入数据库名,表名
|
// Migrate 自动迁移数据库结构,需要输入数据库名,表名
|
||||||
func (db *Executor) Migrate(tableName string, dest interface{}) {
|
func (db *Executor) Migrate(tableName string, dest interface{}) {
|
||||||
typeOf := reflect.TypeOf(dest)
|
typeOf := reflect.TypeOf(dest)
|
||||||
db.migrateCommon(tableName, typeOf)
|
db.migrateCommon(tableName, typeOf)
|
||||||
@@ -72,14 +71,19 @@ func (db *Executor) migrateCommon(tableName string, typeOf reflect.Type) {
|
|||||||
indexsFromCode := db.getIndexsFromCode(typeOf, tableFromCode)
|
indexsFromCode := db.getIndexsFromCode(typeOf, tableFromCode)
|
||||||
|
|
||||||
//获取数据库名称
|
//获取数据库名称
|
||||||
dbNameRows, _ := db.Query("SELECT DATABASE()")
|
var dbName string
|
||||||
dbName := dbNameRows[0]["DATABASE()"].(string)
|
db.RawSql("SELECT DATABASE()").Value("DATABASE()", &dbName)
|
||||||
|
|
||||||
//查询表信息,如果找不到就新建
|
//查询表信息,如果找不到就新建
|
||||||
sql := "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
|
sql := "SELECT TABLE_NAME,ENGINE,TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
|
||||||
dataList, _ := db.Query(sql)
|
var dataList []Table
|
||||||
|
db.RawSql(sql).GetMany(&dataList)
|
||||||
|
for i := 0; i < len(dataList); i++ {
|
||||||
|
dataList[i].TableComment = StringFrom("'" + dataList[i].TableComment.String + "'")
|
||||||
|
}
|
||||||
|
|
||||||
if len(dataList) != 0 {
|
if len(dataList) != 0 {
|
||||||
tableFromDb := getTableFromDb(dataList)
|
tableFromDb := dataList[0]
|
||||||
columnsFromDb := db.getColumnsFromDb(dbName, tableName)
|
columnsFromDb := db.getColumnsFromDb(dbName, tableName)
|
||||||
indexsFromDb := db.getIndexsFromDb(tableName)
|
indexsFromDb := db.getIndexsFromDb(tableName)
|
||||||
|
|
||||||
@@ -91,9 +95,9 @@ func (db *Executor) migrateCommon(tableName string, typeOf reflect.Type) {
|
|||||||
|
|
||||||
func (db *Executor) getTableFromCode(tableName string) Table {
|
func (db *Executor) getTableFromCode(tableName string) Table {
|
||||||
var tableFromCode Table
|
var tableFromCode Table
|
||||||
tableFromCode.TableName = tableName
|
tableFromCode.TableName = StringFrom(tableName)
|
||||||
tableFromCode.Engine = db.getValFromOpinion("ENGINE", "MyISAM")
|
tableFromCode.Engine = StringFrom(db.getValFromOpinion("ENGINE", "MyISAM"))
|
||||||
tableFromCode.Comment = db.getValFromOpinion("COMMENT", "")
|
tableFromCode.TableComment = StringFrom(db.getValFromOpinion("COMMENT", ""))
|
||||||
|
|
||||||
return tableFromCode
|
return tableFromCode
|
||||||
}
|
}
|
||||||
@@ -130,7 +134,7 @@ func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table)
|
|||||||
indexsFromCode = append(indexsFromCode, Index{
|
indexsFromCode = append(indexsFromCode, Index{
|
||||||
NonUnique: 0,
|
NonUnique: 0,
|
||||||
ColumnName: fieldName,
|
ColumnName: fieldName,
|
||||||
KeyName: "idx_" + tableFromCode.TableName + "_" + fieldName,
|
KeyName: "idx_" + tableFromCode.TableName.String + "_" + fieldName,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +143,7 @@ func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table)
|
|||||||
indexsFromCode = append(indexsFromCode, Index{
|
indexsFromCode = append(indexsFromCode, Index{
|
||||||
NonUnique: 1,
|
NonUnique: 1,
|
||||||
ColumnName: fieldName,
|
ColumnName: fieldName,
|
||||||
KeyName: "idx_" + tableFromCode.TableName + "_" + fieldName,
|
KeyName: "idx_" + tableFromCode.TableName.String + "_" + fieldName,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -147,43 +151,16 @@ func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table)
|
|||||||
return indexsFromCode
|
return indexsFromCode
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTableFromDb(dataList []map[string]interface{}) Table {
|
|
||||||
var tableFromDb Table
|
|
||||||
tableFromDb.TableName = fmt.Sprintf("%v", dataList[0]["TABLE_NAME"])
|
|
||||||
tableFromDb.Engine = fmt.Sprintf("%v", dataList[0]["ENGINE"])
|
|
||||||
tableFromDb.Comment = "'" + fmt.Sprintf("%v", dataList[0]["TABLE_COMMENT"]) + "'"
|
|
||||||
|
|
||||||
return tableFromDb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column {
|
func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column {
|
||||||
var columnsFromDb []Column
|
var columnsFromDb []Column
|
||||||
|
|
||||||
sqlColumn := "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
|
sqlColumn := "SELECT COLUMN_NAME,DATA_TYPE,CHARACTER_MAXIMUM_LENGTH as Max_Length,COLUMN_DEFAULT,COLUMN_COMMENT,EXTRA,IS_NULLABLE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
|
||||||
dataColumn, _ := db.Query(sqlColumn)
|
db.RawSql(sqlColumn).GetMany(&columnsFromDb)
|
||||||
|
|
||||||
for j := 0; j < len(dataColumn); j++ {
|
for j := 0; j < len(columnsFromDb); j++ {
|
||||||
dataType := dataColumn[j]["DATA_TYPE"].(string)
|
if columnsFromDb[j].DataType.String == "text" && columnsFromDb[j].MaxLength.Int64 == 65535 {
|
||||||
maxLength, _ := strconv.Atoi(fmt.Sprintf("%v", dataColumn[j]["CHARACTER_MAXIMUM_LENGTH"]))
|
columnsFromDb[j].MaxLength = IntFrom(0)
|
||||||
if dataType == "text" && maxLength == 65535 {
|
|
||||||
maxLength = 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultVal := ""
|
|
||||||
if dataColumn[j]["COLUMN_DEFAULT"] != nil {
|
|
||||||
defaultVal = dataColumn[j]["COLUMN_DEFAULT"].(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
columnsFromDb = append(columnsFromDb, Column{
|
|
||||||
ColumnName: dataColumn[j]["COLUMN_NAME"].(string),
|
|
||||||
DataType: dataType,
|
|
||||||
IsNullable: dataColumn[j]["IS_NULLABLE"].(string),
|
|
||||||
MaxLength: maxLength,
|
|
||||||
ColumnType: dataColumn[j]["COLUMN_TYPE"].(string),
|
|
||||||
ColumnComment: dataColumn[j]["COLUMN_COMMENT"].(string),
|
|
||||||
Extra: dataColumn[j]["EXTRA"].(string),
|
|
||||||
DefaultVal: defaultVal,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return columnsFromDb
|
return columnsFromDb
|
||||||
@@ -191,17 +168,9 @@ func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column {
|
|||||||
|
|
||||||
func (db *Executor) getIndexsFromDb(tableName string) []Index {
|
func (db *Executor) getIndexsFromDb(tableName string) []Index {
|
||||||
sqlIndex := "SHOW INDEXES FROM " + tableName
|
sqlIndex := "SHOW INDEXES FROM " + tableName
|
||||||
dataIndex, _ := db.Query(sqlIndex)
|
|
||||||
|
|
||||||
var indexsFromDb []Index
|
var indexsFromDb []Index
|
||||||
for j := 0; j < len(dataIndex); j++ {
|
db.RawSql(sqlIndex).GetMany(&indexsFromDb)
|
||||||
nonUnique, _ := strconv.Atoi(fmt.Sprintf("%v", dataIndex[j]["Non_unique"]))
|
|
||||||
indexsFromDb = append(indexsFromDb, Index{
|
|
||||||
ColumnName: fmt.Sprintf("%v", dataIndex[j]["Column_name"]),
|
|
||||||
KeyName: fmt.Sprintf("%v", dataIndex[j]["Key_name"]),
|
|
||||||
NonUnique: nonUnique,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexsFromDb
|
return indexsFromDb
|
||||||
}
|
}
|
||||||
@@ -209,7 +178,7 @@ func (db *Executor) getIndexsFromDb(tableName string) []Index {
|
|||||||
// 修改表
|
// 修改表
|
||||||
func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexsFromDb []Index) {
|
func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexsFromDb []Index) {
|
||||||
if tableFromCode.Engine != tableFromDb.Engine {
|
if tableFromCode.Engine != tableFromDb.Engine {
|
||||||
sql := "ALTER TABLE " + tableFromCode.TableName + " Engine " + tableFromCode.Engine
|
sql := "ALTER TABLE " + tableFromCode.TableName.String + " Engine " + tableFromCode.Engine.String
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
@@ -218,8 +187,8 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if tableFromCode.Comment != tableFromDb.Comment {
|
if tableFromCode.TableComment != tableFromDb.TableComment {
|
||||||
sql := "ALTER TABLE " + tableFromCode.TableName + " Comment " + tableFromCode.Comment
|
sql := "ALTER TABLE " + tableFromCode.TableName.String + " Comment " + tableFromCode.TableComment.String
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
@@ -236,8 +205,12 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
|
|||||||
columnDb := columnsFromDb[j]
|
columnDb := columnsFromDb[j]
|
||||||
if columnCode.ColumnName == columnDb.ColumnName {
|
if columnCode.ColumnName == columnDb.ColumnName {
|
||||||
isFind = 1
|
isFind = 1
|
||||||
if columnCode.DataType != columnDb.DataType || columnCode.MaxLength != columnDb.MaxLength || columnCode.ColumnComment != columnDb.ColumnComment || columnCode.Extra != columnDb.Extra || columnCode.DefaultVal != columnDb.DefaultVal {
|
if columnCode.DataType.String != columnDb.DataType.String ||
|
||||||
sql := "ALTER TABLE " + tableFromCode.TableName + " MODIFY " + getColumnStr(columnCode)
|
columnCode.MaxLength.Int64 != columnDb.MaxLength.Int64 ||
|
||||||
|
columnCode.ColumnComment.String != columnDb.ColumnComment.String ||
|
||||||
|
columnCode.Extra.String != columnDb.Extra.String ||
|
||||||
|
columnCode.ColumnDefault.String != columnDb.ColumnDefault.String {
|
||||||
|
sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode)
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
@@ -249,7 +222,7 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isFind == 0 {
|
if isFind == 0 {
|
||||||
sql := "ALTER TABLE " + tableFromCode.TableName + " ADD " + getColumnStr(columnCode)
|
sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode)
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
@@ -268,7 +241,7 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
|
|||||||
if indexCode.ColumnName == indexDb.ColumnName {
|
if indexCode.ColumnName == indexDb.ColumnName {
|
||||||
isFind = 1
|
isFind = 1
|
||||||
if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique {
|
if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique {
|
||||||
sql := "ALTER TABLE " + tableFromCode.TableName + " MODIFY " + getIndexStr(indexCode)
|
sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode)
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
@@ -280,7 +253,7 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isFind == 0 {
|
if isFind == 0 {
|
||||||
sql := "ALTER TABLE " + tableFromCode.TableName + " ADD " + getIndexStr(indexCode)
|
sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode)
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
@@ -305,12 +278,12 @@ func (db *Executor) createTable(tableFromCode Table, columnsFromCode []Column, i
|
|||||||
fieldArr = append(fieldArr, getIndexStr(index))
|
fieldArr = append(fieldArr, getIndexStr(index))
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlStr := "CREATE TABLE `" + tableFromCode.TableName + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + getTableInfoFromCode(tableFromCode) + ";"
|
sqlStr := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + getTableInfoFromCode(tableFromCode) + ";"
|
||||||
_, err := db.Exec(sqlStr)
|
_, err := db.Exec(sqlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("创建表:" + tableFromCode.TableName)
|
fmt.Println("创建表:" + tableFromCode.TableName.String)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -326,33 +299,26 @@ func (db *Executor) getValFromOpinion(key string, def string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getTableInfoFromCode(tableFromCode Table) string {
|
func getTableInfoFromCode(tableFromCode Table) string {
|
||||||
return " ENGINE " + tableFromCode.Engine + " COMMENT " + tableFromCode.Comment
|
return " ENGINE " + tableFromCode.Engine.String + " COMMENT " + tableFromCode.TableComment.String
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获得某列的结构
|
// 获得某列的结构
|
||||||
func getColumnFromCode(fieldName string, fieldType string, fieldMap map[string]string) Column {
|
func getColumnFromCode(fieldName string, fieldType string, fieldMap map[string]string) Column {
|
||||||
var column Column
|
var column Column
|
||||||
//字段名
|
//字段名
|
||||||
column.ColumnName = fieldName
|
column.ColumnName = StringFrom(fieldName)
|
||||||
//字段数据类型
|
//字段数据类型
|
||||||
column.DataType = getDataType(fieldType, fieldMap)
|
column.DataType = StringFrom(getDataType(fieldType, fieldMap))
|
||||||
//字段数据长度
|
//字段数据长度
|
||||||
maxLength := getMaxLength(column.DataType, fieldMap)
|
column.MaxLength = IntFrom(int64(getMaxLength(column.DataType.String, fieldMap)))
|
||||||
columnType := column.DataType
|
|
||||||
if maxLength > 0 {
|
|
||||||
columnType = columnType + "(" + strconv.Itoa(maxLength) + ")"
|
|
||||||
}
|
|
||||||
column.MaxLength = maxLength
|
|
||||||
//字段是否可以为空
|
//字段是否可以为空
|
||||||
column.IsNullable = getNullAble(fieldMap)
|
column.IsNullable = StringFrom(getNullAble(fieldMap))
|
||||||
//字段注释
|
//字段注释
|
||||||
column.ColumnComment = getComment(fieldMap)
|
column.ColumnComment = StringFrom(getComment(fieldMap))
|
||||||
//字段类型
|
|
||||||
column.ColumnType = columnType
|
|
||||||
//扩展信息
|
//扩展信息
|
||||||
column.Extra = getExtra(fieldMap)
|
column.Extra = StringFrom(getExtra(fieldMap))
|
||||||
//默认信息
|
//默认信息
|
||||||
column.DefaultVal = getDefaultVal(fieldMap)
|
column.ColumnDefault = StringFrom(getDefaultVal(fieldMap))
|
||||||
|
|
||||||
return column
|
return column
|
||||||
}
|
}
|
||||||
@@ -375,31 +341,31 @@ func getTagMap(fieldTag string) map[string]string {
|
|||||||
|
|
||||||
func getColumnStr(column Column) string {
|
func getColumnStr(column Column) string {
|
||||||
var strArr []string
|
var strArr []string
|
||||||
strArr = append(strArr, column.ColumnName)
|
strArr = append(strArr, column.ColumnName.String)
|
||||||
if column.MaxLength == 0 {
|
if column.MaxLength.Int64 == 0 {
|
||||||
if column.DataType == "varchar" {
|
if column.DataType.String == "varchar" {
|
||||||
strArr = append(strArr, column.DataType+"(255)")
|
strArr = append(strArr, column.DataType.String+"(255)")
|
||||||
} else {
|
} else {
|
||||||
strArr = append(strArr, column.DataType)
|
strArr = append(strArr, column.DataType.String)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
strArr = append(strArr, column.DataType+"("+strconv.Itoa(column.MaxLength)+")")
|
strArr = append(strArr, column.DataType.String+"("+strconv.Itoa(int(column.MaxLength.Int64))+")")
|
||||||
}
|
}
|
||||||
|
|
||||||
if column.DefaultVal != "" {
|
if column.ColumnDefault.String != "" {
|
||||||
strArr = append(strArr, "DEFAULT '"+column.DefaultVal+"'")
|
strArr = append(strArr, "DEFAULT '"+column.ColumnDefault.String+"'")
|
||||||
}
|
}
|
||||||
|
|
||||||
if column.IsNullable == "NO" {
|
if column.IsNullable.String == "NO" {
|
||||||
strArr = append(strArr, "NOT NULL")
|
strArr = append(strArr, "NOT NULL")
|
||||||
}
|
}
|
||||||
|
|
||||||
if column.ColumnComment != "" {
|
if column.ColumnComment.String != "" {
|
||||||
strArr = append(strArr, "COMMENT '"+column.ColumnComment+"'")
|
strArr = append(strArr, "COMMENT '"+column.ColumnComment.String+"'")
|
||||||
}
|
}
|
||||||
|
|
||||||
if column.Extra != "" {
|
if column.Extra.String != "" {
|
||||||
strArr = append(strArr, column.Extra)
|
strArr = append(strArr, column.Extra.String)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(strArr, " ")
|
return strings.Join(strArr, " ")
|
||||||
|
@@ -94,7 +94,6 @@ func TestAll(t *testing.T) {
|
|||||||
testMin(name, db)
|
testMin(name, db)
|
||||||
testMax(name, db)
|
testMax(name, db)
|
||||||
|
|
||||||
testQuery(name, db)
|
|
||||||
testExec(name, db)
|
testExec(name, db)
|
||||||
|
|
||||||
testTransaction(name, db)
|
testTransaction(name, db)
|
||||||
@@ -127,7 +126,6 @@ func testConnect() *sql.DB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testMigrate(name string, db *sql.DB) {
|
func testMigrate(name string, db *sql.DB) {
|
||||||
|
|
||||||
//AutoMigrate
|
//AutoMigrate
|
||||||
aorm.Use(db).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").AutoMigrate(&Person{})
|
aorm.Use(db).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").AutoMigrate(&Person{})
|
||||||
aorm.Use(db).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "文章").AutoMigrate(&Article{})
|
aorm.Use(db).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "文章").AutoMigrate(&Article{})
|
||||||
@@ -206,7 +204,7 @@ func testGetMany(name string, db *sql.DB) {
|
|||||||
var list []Person
|
var list []Person
|
||||||
errSelect := aorm.Use(db).Debug(false).Where(&Person{Type: aorm.IntFrom(0)}).GetMany(&list)
|
errSelect := aorm.Use(db).Debug(false).Where(&Person{Type: aorm.IntFrom(0)}).GetMany(&list)
|
||||||
if errSelect != nil {
|
if errSelect != nil {
|
||||||
panic(name + "testGetMany" + "found err")
|
panic(name + " testGetMany " + "found err:" + errSelect.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -441,7 +439,7 @@ func testPluck(name string, db *sql.DB) {
|
|||||||
var ageList []int64
|
var ageList []int64
|
||||||
errAgeList := aorm.Use(db).Debug(false).Where(&Person{Type: aorm.IntFrom(0)}).Limit(0, 3).Pluck("age", &ageList)
|
errAgeList := aorm.Use(db).Debug(false).Where(&Person{Type: aorm.IntFrom(0)}).Limit(0, 3).Pluck("age", &ageList)
|
||||||
if errAgeList != nil {
|
if errAgeList != nil {
|
||||||
panic(name + "testPluck" + "found err")
|
panic(name + "testPluck" + "found err:" + errAgeList.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
var moneyList []float32
|
var moneyList []float32
|
||||||
@@ -492,13 +490,6 @@ func testMax(name string, db *sql.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testQuery(name string, db *sql.DB) {
|
|
||||||
_, err := aorm.Use(db).Debug(false).Query("SELECT * FROM person WHERE id=? AND type=?", 1, 3)
|
|
||||||
if err != nil {
|
|
||||||
panic(name + "testQuery" + "found err")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testExec(name string, db *sql.DB) {
|
func testExec(name string, db *sql.DB) {
|
||||||
_, err := aorm.Use(db).Debug(false).Exec("UPDATE person SET name = ? WHERE id=?", "Bob", 3)
|
_, err := aorm.Use(db).Debug(false).Exec("UPDATE person SET name = ? WHERE id=?", "Bob", 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Reference in New Issue
Block a user