add sub query

This commit is contained in:
tangpanqing
2022-12-20 15:52:06 +08:00
parent ce28ce1559
commit 2620229f51
4 changed files with 121 additions and 222 deletions

11
aorm.go
View File

@@ -14,7 +14,10 @@ type LinkCommon interface {
// Executor 查询记录所需要的条件 // Executor 查询记录所需要的条件
type Executor struct { type Executor struct {
//数据库操作连接
linkCommon LinkCommon linkCommon LinkCommon
//查询参数
tableName string tableName string
selectList []string selectList []string
selectExpList []*ExpItem selectExpList []*ExpItem
@@ -27,6 +30,12 @@ type Executor struct {
pageSize int pageSize int
isDebug bool isDebug bool
isLockForUpdate bool isLockForUpdate bool
//sql与参数
sql string
paramList []interface{}
//表属性
opinionList []OpinionItem opinionList []OpinionItem
} }
@@ -63,5 +72,7 @@ func (db *Executor) clear() {
db.pageSize = 0 db.pageSize = 0
db.isDebug = false db.isDebug = false
db.isLockForUpdate = false db.isLockForUpdate = false
db.sql = ""
db.paramList = make([]interface{}, 0)
db.opinionList = make([]OpinionItem, 0) db.opinionList = make([]OpinionItem, 0)
} }

139
crud.go
View File

@@ -134,21 +134,31 @@ func (db *Executor) InsertBatch(values interface{}) (int64, error) {
return count, nil return count, nil
} }
// GetMany 查询记录(新) // GetRows 获取行操作
func (db *Executor) GetMany(values interface{}) error { func (db *Executor) GetRows() (*sql.Rows, error) {
sqlStr, paramList := db.GetSqlAndParams() sqlStr, paramList := db.GetSqlAndParams()
smt, errSmt := db.linkCommon.Prepare(sqlStr) smt, errSmt := db.linkCommon.Prepare(sqlStr)
if errSmt != nil { if errSmt != nil {
return errSmt return nil, errSmt
} }
defer smt.Close() //defer smt.Close()
rows, errRows := smt.Query(paramList...) rows, errRows := smt.Query(paramList...)
if errRows != nil {
return nil, errRows
}
return rows, nil
}
// GetMany 查询记录(新)
func (db *Executor) GetMany(values interface{}) error {
rows, errRows := db.GetRows()
defer rows.Close()
if errRows != nil { if errRows != nil {
return errRows return errRows
} }
defer rows.Close()
destSlice := reflect.Indirect(reflect.ValueOf(values)) destSlice := reflect.Indirect(reflect.ValueOf(values))
destType := destSlice.Type().Elem() destType := destSlice.Type().Elem()
@@ -181,19 +191,11 @@ func (db *Executor) GetMany(values interface{}) error {
func (db *Executor) GetOne(obj interface{}) error { func (db *Executor) GetOne(obj interface{}) error {
db.Limit(0, 1) db.Limit(0, 1)
sqlStr, paramList := db.GetSqlAndParams() rows, errRows := db.GetRows()
defer rows.Close()
smt, errSmt := db.linkCommon.Prepare(sqlStr)
if errSmt != nil {
return errSmt
}
defer smt.Close()
rows, errRows := smt.Query(paramList...)
if errRows != nil { if errRows != nil {
return errRows return errRows
} }
defer rows.Close()
destType := reflect.TypeOf(obj).Elem() destType := reflect.TypeOf(obj).Elem()
destValue := reflect.ValueOf(obj).Elem() destValue := reflect.ValueOf(obj).Elem()
@@ -218,8 +220,19 @@ func (db *Executor) GetOne(obj interface{}) error {
return nil return nil
} }
func (db *Executor) GetSqlAndParams() (string, []any) { // RawSql 执行原始的sql语句
var paramList []any func (db *Executor) RawSql(sql string, paramList ...interface{}) *Executor {
db.sql = sql
db.paramList = paramList
return db
}
func (db *Executor) GetSqlAndParams() (string, []interface{}) {
if db.sql != "" {
return db.sql, db.paramList
}
var paramList []interface{}
fieldStr, paramList := handleField(db.selectList, db.selectExpList, paramList) fieldStr, paramList := handleField(db.selectList, db.selectExpList, paramList)
whereStr, paramList := handleWhere(db.whereList, paramList) whereStr, paramList := handleWhere(db.whereList, paramList)
@@ -323,22 +336,13 @@ func (db *Executor) Min(fieldName string) (float64, error) {
// Value 字段值 // Value 字段值
func (db *Executor) Value(fieldName string, dest interface{}) error { func (db *Executor) Value(fieldName string, dest interface{}) error {
db.Select(fieldName).Limit(0, 1) db.Select(fieldName).Limit(0, 1)
sqlStr, paramList := db.GetSqlAndParams() rows, errRows := db.GetRows()
defer rows.Close()
smt, errSmt := db.linkCommon.Prepare(sqlStr)
if errSmt != nil {
return errSmt
}
defer smt.Close()
rows, errRows := smt.Query(paramList...)
if errRows != nil { if errRows != nil {
return errRows return errRows
} }
defer rows.Close()
destValue := reflect.ValueOf(dest).Elem() destValue := reflect.ValueOf(dest).Elem()
@@ -372,19 +376,11 @@ func (db *Executor) Value(fieldName string, dest interface{}) error {
func (db *Executor) Pluck(fieldName string, values interface{}) error { func (db *Executor) Pluck(fieldName string, values interface{}) error {
db.Select(fieldName) db.Select(fieldName)
sqlStr, paramList := db.GetSqlAndParams() rows, errRows := db.GetRows()
defer rows.Close()
smt, errSmt := db.linkCommon.Prepare(sqlStr)
if errSmt != nil {
return errSmt
}
defer smt.Close()
rows, errRows := smt.Query(paramList...)
if errRows != nil { if errRows != nil {
return errRows return errRows
} }
defer rows.Close()
destSlice := reflect.Indirect(reflect.ValueOf(values)) destSlice := reflect.Indirect(reflect.ValueOf(values))
destType := destSlice.Type().Elem() destType := destSlice.Type().Elem()
@@ -438,71 +434,6 @@ func (db *Executor) Decrement(fieldName string, step int) (int64, error) {
return db.ExecAffected(sqlStr, paramList...) return db.ExecAffected(sqlStr, paramList...)
} }
// Query 通用查询
func (db *Executor) Query(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) {
if db.isDebug {
fmt.Println(sqlStr)
fmt.Println(args...)
}
var listData []map[string]interface{}
smt, err1 := db.linkCommon.Prepare(sqlStr)
if err1 != nil {
return listData, err1
}
defer smt.Close()
rows, err2 := smt.Query(args...)
if err2 != nil {
return listData, err2
}
defer rows.Close()
fieldsTypes, errType := rows.ColumnTypes()
if errType != nil {
return make([]map[string]interface{}, 0), errType
}
fields, errColumns := rows.Columns()
if errColumns != nil {
return make([]map[string]interface{}, 0), errColumns
}
for rows.Next() {
data := make(map[string]interface{})
scans := make([]interface{}, len(fields))
for i := range scans {
scans[i] = &scans[i]
}
err := rows.Scan(scans...)
if err != nil {
return make([]map[string]interface{}, 0), err
}
for i, v := range scans {
if v == nil {
data[fields[i]] = v
} else {
if fieldsTypes[i].DatabaseTypeName() == "VARCHAR" || fieldsTypes[i].DatabaseTypeName() == "TEXT" || fieldsTypes[i].DatabaseTypeName() == "CHAR" || fieldsTypes[i].DatabaseTypeName() == "LONGTEXT" {
data[fields[i]] = fmt.Sprintf("%s", v)
} else if fieldsTypes[i].DatabaseTypeName() == "INT" || fieldsTypes[i].DatabaseTypeName() == "BIGINT" || fieldsTypes[i].DatabaseTypeName() == "UNSIGNED INT" || fieldsTypes[i].DatabaseTypeName() == "UNSIGNED BIGINT" {
data[fields[i]] = fmt.Sprintf("%v", v)
} else if fieldsTypes[i].DatabaseTypeName() == "DECIMAL" {
data[fields[i]] = string(v.([]uint8))
} else {
data[fields[i]] = v
}
}
}
listData = append(listData, data)
}
db.clear()
return listData, nil
}
// Exec 通用执行-新增,更新,删除 // Exec 通用执行-新增,更新,删除
func (db *Executor) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { func (db *Executor) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
if db.isDebug { if db.isDebug {
@@ -1177,7 +1108,7 @@ func getFieldNameMap(destValue reflect.Value, destType reflect.Type) map[string]
func getScans(columnNameList []string, fieldNameMap map[string]int, destValue reflect.Value) []interface{} { func getScans(columnNameList []string, fieldNameMap map[string]int, destValue reflect.Value) []interface{} {
var scans []interface{} var scans []interface{}
for _, columnName := range columnNameList { for _, columnName := range columnNameList {
fieldName := CamelString(columnName) fieldName := CamelString(strings.ToLower(columnName))
index, ok := fieldNameMap[fieldName] index, ok := fieldNameMap[fieldName]
if ok { if ok {
scans = append(scans, destValue.Field(index).Addr().Interface()) scans = append(scans, destValue.Field(index).Addr().Interface())

View File

@@ -8,21 +8,19 @@ import (
) )
type Table struct { type Table struct {
TableName string TableName String
Engine string Engine String
Comment string TableComment String
} }
type Column struct { type Column struct {
ColumnName string ColumnName String
ColumnDefault string ColumnDefault String
IsNullable string IsNullable String
DataType string //数据类型 varchar,bigint,int DataType String //数据类型 varchar,bigint,int
MaxLength int //数据最大长度 20 MaxLength Int //数据最大长度 20
ColumnType string //列类型 varchar(20) ColumnComment String
ColumnComment string Extra String //扩展信息 auto_increment
Extra string //扩展信息 auto_increment
DefaultVal string //默认值
} }
type Index struct { type Index struct {
@@ -47,11 +45,12 @@ func (db *Executor) Opinion(key string, val string) *Executor {
} }
func (db *Executor) ShowCreateTable(tableName string) string { func (db *Executor) ShowCreateTable(tableName string) string {
list, _ := db.Query("show create table " + tableName) var str string
return list[0]["Create Table"].(string) db.RawSql("show create table "+tableName).Value("Create Table", &str)
return str
} }
// Migrate 迁移数据库结构,需要输入数据库名,表名自动获取 // AutoMigrate 迁移数据库结构,需要输入数据库名,表名自动获取
func (db *Executor) AutoMigrate(dest interface{}) { func (db *Executor) AutoMigrate(dest interface{}) {
typeOf := reflect.TypeOf(dest) typeOf := reflect.TypeOf(dest)
arr := strings.Split(typeOf.String(), ".") arr := strings.Split(typeOf.String(), ".")
@@ -60,7 +59,7 @@ func (db *Executor) AutoMigrate(dest interface{}) {
db.migrateCommon(tableName, typeOf) db.migrateCommon(tableName, typeOf)
} }
// AutoMigrate 自动迁移数据库结构,需要输入数据库名,表名 // Migrate 自动迁移数据库结构,需要输入数据库名,表名
func (db *Executor) Migrate(tableName string, dest interface{}) { func (db *Executor) Migrate(tableName string, dest interface{}) {
typeOf := reflect.TypeOf(dest) typeOf := reflect.TypeOf(dest)
db.migrateCommon(tableName, typeOf) db.migrateCommon(tableName, typeOf)
@@ -72,14 +71,19 @@ func (db *Executor) migrateCommon(tableName string, typeOf reflect.Type) {
indexsFromCode := db.getIndexsFromCode(typeOf, tableFromCode) indexsFromCode := db.getIndexsFromCode(typeOf, tableFromCode)
//获取数据库名称 //获取数据库名称
dbNameRows, _ := db.Query("SELECT DATABASE()") var dbName string
dbName := dbNameRows[0]["DATABASE()"].(string) db.RawSql("SELECT DATABASE()").Value("DATABASE()", &dbName)
//查询表信息,如果找不到就新建 //查询表信息,如果找不到就新建
sql := "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'" sql := "SELECT TABLE_NAME,ENGINE,TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
dataList, _ := db.Query(sql) var dataList []Table
db.RawSql(sql).GetMany(&dataList)
for i := 0; i < len(dataList); i++ {
dataList[i].TableComment = StringFrom("'" + dataList[i].TableComment.String + "'")
}
if len(dataList) != 0 { if len(dataList) != 0 {
tableFromDb := getTableFromDb(dataList) tableFromDb := dataList[0]
columnsFromDb := db.getColumnsFromDb(dbName, tableName) columnsFromDb := db.getColumnsFromDb(dbName, tableName)
indexsFromDb := db.getIndexsFromDb(tableName) indexsFromDb := db.getIndexsFromDb(tableName)
@@ -91,9 +95,9 @@ func (db *Executor) migrateCommon(tableName string, typeOf reflect.Type) {
func (db *Executor) getTableFromCode(tableName string) Table { func (db *Executor) getTableFromCode(tableName string) Table {
var tableFromCode Table var tableFromCode Table
tableFromCode.TableName = tableName tableFromCode.TableName = StringFrom(tableName)
tableFromCode.Engine = db.getValFromOpinion("ENGINE", "MyISAM") tableFromCode.Engine = StringFrom(db.getValFromOpinion("ENGINE", "MyISAM"))
tableFromCode.Comment = db.getValFromOpinion("COMMENT", "") tableFromCode.TableComment = StringFrom(db.getValFromOpinion("COMMENT", ""))
return tableFromCode return tableFromCode
} }
@@ -130,7 +134,7 @@ func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table)
indexsFromCode = append(indexsFromCode, Index{ indexsFromCode = append(indexsFromCode, Index{
NonUnique: 0, NonUnique: 0,
ColumnName: fieldName, ColumnName: fieldName,
KeyName: "idx_" + tableFromCode.TableName + "_" + fieldName, KeyName: "idx_" + tableFromCode.TableName.String + "_" + fieldName,
}) })
} }
@@ -139,7 +143,7 @@ func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table)
indexsFromCode = append(indexsFromCode, Index{ indexsFromCode = append(indexsFromCode, Index{
NonUnique: 1, NonUnique: 1,
ColumnName: fieldName, ColumnName: fieldName,
KeyName: "idx_" + tableFromCode.TableName + "_" + fieldName, KeyName: "idx_" + tableFromCode.TableName.String + "_" + fieldName,
}) })
} }
} }
@@ -147,43 +151,16 @@ func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table)
return indexsFromCode return indexsFromCode
} }
func getTableFromDb(dataList []map[string]interface{}) Table {
var tableFromDb Table
tableFromDb.TableName = fmt.Sprintf("%v", dataList[0]["TABLE_NAME"])
tableFromDb.Engine = fmt.Sprintf("%v", dataList[0]["ENGINE"])
tableFromDb.Comment = "'" + fmt.Sprintf("%v", dataList[0]["TABLE_COMMENT"]) + "'"
return tableFromDb
}
func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column { func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column {
var columnsFromDb []Column var columnsFromDb []Column
sqlColumn := "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'" sqlColumn := "SELECT COLUMN_NAME,DATA_TYPE,CHARACTER_MAXIMUM_LENGTH as Max_Length,COLUMN_DEFAULT,COLUMN_COMMENT,EXTRA,IS_NULLABLE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'"
dataColumn, _ := db.Query(sqlColumn) db.RawSql(sqlColumn).GetMany(&columnsFromDb)
for j := 0; j < len(dataColumn); j++ { for j := 0; j < len(columnsFromDb); j++ {
dataType := dataColumn[j]["DATA_TYPE"].(string) if columnsFromDb[j].DataType.String == "text" && columnsFromDb[j].MaxLength.Int64 == 65535 {
maxLength, _ := strconv.Atoi(fmt.Sprintf("%v", dataColumn[j]["CHARACTER_MAXIMUM_LENGTH"])) columnsFromDb[j].MaxLength = IntFrom(0)
if dataType == "text" && maxLength == 65535 {
maxLength = 0
} }
defaultVal := ""
if dataColumn[j]["COLUMN_DEFAULT"] != nil {
defaultVal = dataColumn[j]["COLUMN_DEFAULT"].(string)
}
columnsFromDb = append(columnsFromDb, Column{
ColumnName: dataColumn[j]["COLUMN_NAME"].(string),
DataType: dataType,
IsNullable: dataColumn[j]["IS_NULLABLE"].(string),
MaxLength: maxLength,
ColumnType: dataColumn[j]["COLUMN_TYPE"].(string),
ColumnComment: dataColumn[j]["COLUMN_COMMENT"].(string),
Extra: dataColumn[j]["EXTRA"].(string),
DefaultVal: defaultVal,
})
} }
return columnsFromDb return columnsFromDb
@@ -191,17 +168,9 @@ func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column {
func (db *Executor) getIndexsFromDb(tableName string) []Index { func (db *Executor) getIndexsFromDb(tableName string) []Index {
sqlIndex := "SHOW INDEXES FROM " + tableName sqlIndex := "SHOW INDEXES FROM " + tableName
dataIndex, _ := db.Query(sqlIndex)
var indexsFromDb []Index var indexsFromDb []Index
for j := 0; j < len(dataIndex); j++ { db.RawSql(sqlIndex).GetMany(&indexsFromDb)
nonUnique, _ := strconv.Atoi(fmt.Sprintf("%v", dataIndex[j]["Non_unique"]))
indexsFromDb = append(indexsFromDb, Index{
ColumnName: fmt.Sprintf("%v", dataIndex[j]["Column_name"]),
KeyName: fmt.Sprintf("%v", dataIndex[j]["Key_name"]),
NonUnique: nonUnique,
})
}
return indexsFromDb return indexsFromDb
} }
@@ -209,7 +178,7 @@ func (db *Executor) getIndexsFromDb(tableName string) []Index {
// 修改表 // 修改表
func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexsFromDb []Index) { func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexsFromDb []Index) {
if tableFromCode.Engine != tableFromDb.Engine { if tableFromCode.Engine != tableFromDb.Engine {
sql := "ALTER TABLE " + tableFromCode.TableName + " Engine " + tableFromCode.Engine sql := "ALTER TABLE " + tableFromCode.TableName.String + " Engine " + tableFromCode.Engine.String
_, err := db.Exec(sql) _, err := db.Exec(sql)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -218,8 +187,8 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
} }
} }
if tableFromCode.Comment != tableFromDb.Comment { if tableFromCode.TableComment != tableFromDb.TableComment {
sql := "ALTER TABLE " + tableFromCode.TableName + " Comment " + tableFromCode.Comment sql := "ALTER TABLE " + tableFromCode.TableName.String + " Comment " + tableFromCode.TableComment.String
_, err := db.Exec(sql) _, err := db.Exec(sql)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -236,8 +205,12 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
columnDb := columnsFromDb[j] columnDb := columnsFromDb[j]
if columnCode.ColumnName == columnDb.ColumnName { if columnCode.ColumnName == columnDb.ColumnName {
isFind = 1 isFind = 1
if columnCode.DataType != columnDb.DataType || columnCode.MaxLength != columnDb.MaxLength || columnCode.ColumnComment != columnDb.ColumnComment || columnCode.Extra != columnDb.Extra || columnCode.DefaultVal != columnDb.DefaultVal { if columnCode.DataType.String != columnDb.DataType.String ||
sql := "ALTER TABLE " + tableFromCode.TableName + " MODIFY " + getColumnStr(columnCode) columnCode.MaxLength.Int64 != columnDb.MaxLength.Int64 ||
columnCode.ColumnComment.String != columnDb.ColumnComment.String ||
columnCode.Extra.String != columnDb.Extra.String ||
columnCode.ColumnDefault.String != columnDb.ColumnDefault.String {
sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode)
_, err := db.Exec(sql) _, err := db.Exec(sql)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -249,7 +222,7 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
} }
if isFind == 0 { if isFind == 0 {
sql := "ALTER TABLE " + tableFromCode.TableName + " ADD " + getColumnStr(columnCode) sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode)
_, err := db.Exec(sql) _, err := db.Exec(sql)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -268,7 +241,7 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
if indexCode.ColumnName == indexDb.ColumnName { if indexCode.ColumnName == indexDb.ColumnName {
isFind = 1 isFind = 1
if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique {
sql := "ALTER TABLE " + tableFromCode.TableName + " MODIFY " + getIndexStr(indexCode) sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode)
_, err := db.Exec(sql) _, err := db.Exec(sql)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -280,7 +253,7 @@ func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, i
} }
if isFind == 0 { if isFind == 0 {
sql := "ALTER TABLE " + tableFromCode.TableName + " ADD " + getIndexStr(indexCode) sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode)
_, err := db.Exec(sql) _, err := db.Exec(sql)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -305,12 +278,12 @@ func (db *Executor) createTable(tableFromCode Table, columnsFromCode []Column, i
fieldArr = append(fieldArr, getIndexStr(index)) fieldArr = append(fieldArr, getIndexStr(index))
} }
sqlStr := "CREATE TABLE `" + tableFromCode.TableName + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + getTableInfoFromCode(tableFromCode) + ";" sqlStr := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + getTableInfoFromCode(tableFromCode) + ";"
_, err := db.Exec(sqlStr) _, err := db.Exec(sqlStr)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
fmt.Println("创建表:" + tableFromCode.TableName) fmt.Println("创建表:" + tableFromCode.TableName.String)
} }
} }
@@ -326,33 +299,26 @@ func (db *Executor) getValFromOpinion(key string, def string) string {
} }
func getTableInfoFromCode(tableFromCode Table) string { func getTableInfoFromCode(tableFromCode Table) string {
return " ENGINE " + tableFromCode.Engine + " COMMENT " + tableFromCode.Comment return " ENGINE " + tableFromCode.Engine.String + " COMMENT " + tableFromCode.TableComment.String
} }
// 获得某列的结构 // 获得某列的结构
func getColumnFromCode(fieldName string, fieldType string, fieldMap map[string]string) Column { func getColumnFromCode(fieldName string, fieldType string, fieldMap map[string]string) Column {
var column Column var column Column
//字段名 //字段名
column.ColumnName = fieldName column.ColumnName = StringFrom(fieldName)
//字段数据类型 //字段数据类型
column.DataType = getDataType(fieldType, fieldMap) column.DataType = StringFrom(getDataType(fieldType, fieldMap))
//字段数据长度 //字段数据长度
maxLength := getMaxLength(column.DataType, fieldMap) column.MaxLength = IntFrom(int64(getMaxLength(column.DataType.String, fieldMap)))
columnType := column.DataType
if maxLength > 0 {
columnType = columnType + "(" + strconv.Itoa(maxLength) + ")"
}
column.MaxLength = maxLength
//字段是否可以为空 //字段是否可以为空
column.IsNullable = getNullAble(fieldMap) column.IsNullable = StringFrom(getNullAble(fieldMap))
//字段注释 //字段注释
column.ColumnComment = getComment(fieldMap) column.ColumnComment = StringFrom(getComment(fieldMap))
//字段类型
column.ColumnType = columnType
//扩展信息 //扩展信息
column.Extra = getExtra(fieldMap) column.Extra = StringFrom(getExtra(fieldMap))
//默认信息 //默认信息
column.DefaultVal = getDefaultVal(fieldMap) column.ColumnDefault = StringFrom(getDefaultVal(fieldMap))
return column return column
} }
@@ -375,31 +341,31 @@ func getTagMap(fieldTag string) map[string]string {
func getColumnStr(column Column) string { func getColumnStr(column Column) string {
var strArr []string var strArr []string
strArr = append(strArr, column.ColumnName) strArr = append(strArr, column.ColumnName.String)
if column.MaxLength == 0 { if column.MaxLength.Int64 == 0 {
if column.DataType == "varchar" { if column.DataType.String == "varchar" {
strArr = append(strArr, column.DataType+"(255)") strArr = append(strArr, column.DataType.String+"(255)")
} else { } else {
strArr = append(strArr, column.DataType) strArr = append(strArr, column.DataType.String)
} }
} else { } else {
strArr = append(strArr, column.DataType+"("+strconv.Itoa(column.MaxLength)+")") strArr = append(strArr, column.DataType.String+"("+strconv.Itoa(int(column.MaxLength.Int64))+")")
} }
if column.DefaultVal != "" { if column.ColumnDefault.String != "" {
strArr = append(strArr, "DEFAULT '"+column.DefaultVal+"'") strArr = append(strArr, "DEFAULT '"+column.ColumnDefault.String+"'")
} }
if column.IsNullable == "NO" { if column.IsNullable.String == "NO" {
strArr = append(strArr, "NOT NULL") strArr = append(strArr, "NOT NULL")
} }
if column.ColumnComment != "" { if column.ColumnComment.String != "" {
strArr = append(strArr, "COMMENT '"+column.ColumnComment+"'") strArr = append(strArr, "COMMENT '"+column.ColumnComment.String+"'")
} }
if column.Extra != "" { if column.Extra.String != "" {
strArr = append(strArr, column.Extra) strArr = append(strArr, column.Extra.String)
} }
return strings.Join(strArr, " ") return strings.Join(strArr, " ")

View File

@@ -94,7 +94,6 @@ func TestAll(t *testing.T) {
testMin(name, db) testMin(name, db)
testMax(name, db) testMax(name, db)
testQuery(name, db)
testExec(name, db) testExec(name, db)
testTransaction(name, db) testTransaction(name, db)
@@ -127,7 +126,6 @@ func testConnect() *sql.DB {
} }
func testMigrate(name string, db *sql.DB) { func testMigrate(name string, db *sql.DB) {
//AutoMigrate //AutoMigrate
aorm.Use(db).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").AutoMigrate(&Person{}) aorm.Use(db).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").AutoMigrate(&Person{})
aorm.Use(db).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "文章").AutoMigrate(&Article{}) aorm.Use(db).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "文章").AutoMigrate(&Article{})
@@ -206,7 +204,7 @@ func testGetMany(name string, db *sql.DB) {
var list []Person var list []Person
errSelect := aorm.Use(db).Debug(false).Where(&Person{Type: aorm.IntFrom(0)}).GetMany(&list) errSelect := aorm.Use(db).Debug(false).Where(&Person{Type: aorm.IntFrom(0)}).GetMany(&list)
if errSelect != nil { if errSelect != nil {
panic(name + "testGetMany" + "found err") panic(name + " testGetMany " + "found err:" + errSelect.Error())
} }
} }
@@ -441,7 +439,7 @@ func testPluck(name string, db *sql.DB) {
var ageList []int64 var ageList []int64
errAgeList := aorm.Use(db).Debug(false).Where(&Person{Type: aorm.IntFrom(0)}).Limit(0, 3).Pluck("age", &ageList) errAgeList := aorm.Use(db).Debug(false).Where(&Person{Type: aorm.IntFrom(0)}).Limit(0, 3).Pluck("age", &ageList)
if errAgeList != nil { if errAgeList != nil {
panic(name + "testPluck" + "found err") panic(name + "testPluck" + "found err:" + errAgeList.Error())
} }
var moneyList []float32 var moneyList []float32
@@ -492,13 +490,6 @@ func testMax(name string, db *sql.DB) {
} }
} }
func testQuery(name string, db *sql.DB) {
_, err := aorm.Use(db).Debug(false).Query("SELECT * FROM person WHERE id=? AND type=?", 1, 3)
if err != nil {
panic(name + "testQuery" + "found err")
}
}
func testExec(name string, db *sql.DB) { func testExec(name string, db *sql.DB) {
_, err := aorm.Use(db).Debug(false).Exec("UPDATE person SET name = ? WHERE id=?", "Bob", 3) _, err := aorm.Use(db).Debug(false).Exec("UPDATE person SET name = ? WHERE id=?", "Bob", 3)
if err != nil { if err != nil {