From fccbc17adaccbbfd2e594c9a8c9ea69921c5f0f5 Mon Sep 17 00:00:00 2001 From: tangpanqing Date: Fri, 23 Dec 2022 11:10:17 +0800 Subject: [PATCH] support sqlite3 --- .gitignore | 1 + aorm.go | 16 +- builder/aggregation.go | 66 +++ builder/crud.go | 644 +++++++++++++++++++++ builder/handle.go | 119 ++++ builder/having.go | 151 +++++ builder/join.go | 19 + builder/select.go | 46 ++ builder/where.go | 151 +++++ executor/crud.go | 1124 ------------------------------------ executor/executor.go | 38 -- migrate_mysql/migrate.go | 4 +- migrate_sqlite3/migrate.go | 88 +-- migrator/migrator.go | 8 +- test/aorm_test.go | 126 ++-- 15 files changed, 1308 insertions(+), 1293 deletions(-) create mode 100644 builder/aggregation.go create mode 100644 builder/crud.go create mode 100644 builder/handle.go create mode 100644 builder/having.go create mode 100644 builder/join.go create mode 100644 builder/select.go create mode 100644 builder/where.go delete mode 100644 executor/crud.go delete mode 100644 executor/executor.go diff --git a/.gitignore b/.gitignore index 4170155..ba89277 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /go.sum +/test/test.db diff --git a/aorm.go b/aorm.go index 3268f5d..4449479 100644 --- a/aorm.go +++ b/aorm.go @@ -2,7 +2,7 @@ package aorm import ( "database/sql" //只需导入你需要的驱动即可 - "github.com/tangpanqing/aorm/executor" + "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/migrator" "github.com/tangpanqing/aorm/model" ) @@ -27,8 +27,8 @@ func Open(driverName string, dataSourceName string) (DbContent, error) { } // Use 开始一个数据库操作 -func Use(linkCommon model.LinkCommon) *executor.Executor { - executor := &executor.Executor{ +func Use(linkCommon model.LinkCommon) *builder.Builder { + executor := &builder.Builder{ LinkCommon: linkCommon, } @@ -36,8 +36,8 @@ func Use(linkCommon model.LinkCommon) *executor.Executor { } // Sub 子查询 -func Sub() *executor.Executor { - executor := &executor.Executor{} +func Sub() *builder.Builder { + executor := &builder.Builder{} return executor } @@ -50,13 +50,13 @@ func Migrator(linkCommon model.LinkCommon) *migrator.Migrator { } //清空查询条件,复用对象 -//func (ex *executor.Executor) clear() { +//func (ex *builder.Executor) clear() { // ex.tableName = "" // ex.selectList = make([]string, 0) // ex.groupList = make([]string, 0) -// ex.whereList = make([]executor.WhereItem, 0) +// ex.whereList = make([]builder.WhereItem, 0) // ex.joinList = make([]string, 0) -// ex.havingList = make([]executor.WhereItem, 0) +// ex.havingList = make([]builder.WhereItem, 0) // ex.orderList = make([]string, 0) // ex.offset = 0 // ex.pageSize = 0 diff --git a/builder/aggregation.go b/builder/aggregation.go new file mode 100644 index 0000000..c08b586 --- /dev/null +++ b/builder/aggregation.go @@ -0,0 +1,66 @@ +package builder + +import "github.com/tangpanqing/aorm/null" + +type IntStruct struct { + C null.Int +} + +type FloatStruct struct { + C null.Float +} + +// Count 聚合函数-数量 +func (ex *Builder) Count(fieldName string) (int64, error) { + var obj []IntStruct + err := ex.SelectCount(fieldName, "c").GetMany(&obj) + if err != nil { + return 0, err + } + + return obj[0].C.Int64, nil +} + +// Sum 聚合函数-合计 +func (ex *Builder) Sum(fieldName string) (float64, error) { + var obj []FloatStruct + err := ex.SelectSum(fieldName, "c").GetMany(&obj) + if err != nil { + return 0, err + } + + return obj[0].C.Float64, nil +} + +// Avg 聚合函数-平均值 +func (ex *Builder) Avg(fieldName string) (float64, error) { + var obj []FloatStruct + err := ex.SelectAvg(fieldName, "c").GetMany(&obj) + if err != nil { + return 0, err + } + + return obj[0].C.Float64, nil +} + +// Max 聚合函数-最大值 +func (ex *Builder) Max(fieldName string) (float64, error) { + var obj []FloatStruct + err := ex.SelectMax(fieldName, "c").GetMany(&obj) + if err != nil { + return 0, err + } + + return obj[0].C.Float64, nil +} + +// Min 聚合函数-最小值 +func (ex *Builder) Min(fieldName string) (float64, error) { + var obj []FloatStruct + err := ex.SelectMin(fieldName, "c").GetMany(&obj) + if err != nil { + return 0, err + } + + return obj[0].C.Float64, nil +} diff --git a/builder/crud.go b/builder/crud.go new file mode 100644 index 0000000..1b45e27 --- /dev/null +++ b/builder/crud.go @@ -0,0 +1,644 @@ +package builder + +import ( + "database/sql" + "errors" + "fmt" + "github.com/tangpanqing/aorm/helper" + "github.com/tangpanqing/aorm/model" + "reflect" + "strings" + "unsafe" +) + +const Desc = "DESC" +const Asc = "ASC" + +const Eq = "=" +const Ne = "!=" +const Gt = ">" +const Ge = ">=" +const Lt = "<" +const Le = "<=" + +const In = "IN" +const NotIn = "NOT IN" +const Like = "LIKE" +const NotLike = "NOT LIKE" +const Between = "BETWEEN" +const NotBetween = "NOT BETWEEN" + +const Raw = "Raw" + +// SelectItem 将某子语句重命名为某字段 +type SelectItem struct { + Executor **Builder + FieldName string +} + +// Builder 查询记录所需要的条件 +type Builder struct { + //数据库操作连接 + LinkCommon model.LinkCommon + + //查询参数 + tableName string + selectList []string + selectExpList []*SelectItem + groupList []string + whereList []WhereItem + joinList []string + havingList []WhereItem + orderList []string + offset int + pageSize int + isDebug bool + isLockForUpdate bool + + //sql与参数 + sql string + paramList []interface{} + + //驱动名字 + driverName string +} + +type WhereItem struct { + Field string + Opt string + Val any +} + +func (ex *Builder) Driver(driverName string) *Builder { + ex.driverName = driverName + return ex +} + +// Insert 增加记录 +func (ex *Builder) Insert(dest interface{}) (int64, error) { + typeOf := reflect.TypeOf(dest) + valueOf := reflect.ValueOf(dest) + + //如果没有设置表名 + if ex.tableName == "" { + ex.tableName = getTableName(typeOf, valueOf) + } + + var keys []string + var paramList []any + var place []string + for i := 0; i < typeOf.Elem().NumField(); i++ { + isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() + if isNotNull { + key := helper.UnderLine(typeOf.Elem().Field(i).Name) + val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() + keys = append(keys, key) + paramList = append(paramList, val) + place = append(place, "?") + } + } + + sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" + + res, err := ex.Exec(sqlStr, paramList...) + if err != nil { + return 0, err + } + + lastId, err := res.LastInsertId() + if err != nil { + return 0, err + } + + return lastId, nil +} + +// InsertBatch 批量增加记录 +func (ex *Builder) InsertBatch(values interface{}) (int64, error) { + + var keys []string + var paramList []any + var place []string + + valueOf := reflect.ValueOf(values).Elem() + if valueOf.Len() == 0 { + return 0, errors.New("the data list for insert batch not found") + } + typeOf := reflect.TypeOf(values).Elem().Elem() + + //如果没有设置表名 + if ex.tableName == "" { + ex.tableName = getTableName(typeOf, valueOf.Index(0)) + } + + for j := 0; j < valueOf.Len(); j++ { + var placeItem []string + + for i := 0; i < valueOf.Index(j).NumField(); i++ { + isNotNull := valueOf.Index(j).Field(i).Field(0).Field(1).Bool() + if isNotNull { + if j == 0 { + key := helper.UnderLine(typeOf.Field(i).Name) + keys = append(keys, key) + } + + val := valueOf.Index(j).Field(i).Field(0).Field(0).Interface() + paramList = append(paramList, val) + placeItem = append(placeItem, "?") + } + } + + place = append(place, "("+strings.Join(placeItem, ",")+")") + } + + sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") + + res, err := ex.Exec(sqlStr, paramList...) + if err != nil { + return 0, err + } + + count, err := res.RowsAffected() + if err != nil { + return 0, err + } + + return count, nil +} + +// GetRows 获取行操作 +func (ex *Builder) GetRows() (*sql.Rows, error) { + sqlStr, paramList := ex.GetSqlAndParams() + + smt, errSmt := ex.LinkCommon.Prepare(sqlStr) + if errSmt != nil { + return nil, errSmt + } + //defer smt.Close() + + rows, errRows := smt.Query(paramList...) + if errRows != nil { + return nil, errRows + } + + return rows, nil +} + +// GetMany 查询记录(新) +func (ex *Builder) GetMany(values interface{}) error { + rows, errRows := ex.GetRows() + defer rows.Close() + if errRows != nil { + return errRows + } + + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + destValue := reflect.New(destType).Elem() + + //从数据库中读出来的字段名字 + columnNameList, errColumns := rows.Columns() + if errColumns != nil { + return errColumns + } + + //从结构体反射出来的属性名 + fieldNameMap := getFieldNameMap(destValue, destType) + + for rows.Next() { + scans := getScans(columnNameList, fieldNameMap, destValue) + + errScan := rows.Scan(scans...) + if errScan != nil { + return errScan + } + + destSlice.Set(reflect.Append(destSlice, destValue)) + } + + return nil +} + +// GetOne 查询某一条记录 +func (ex *Builder) GetOne(obj interface{}) error { + ex.Limit(0, 1) + + rows, errRows := ex.GetRows() + defer rows.Close() + if errRows != nil { + return errRows + } + + destType := reflect.TypeOf(obj).Elem() + destValue := reflect.ValueOf(obj).Elem() + + //从数据库中读出来的字段名字 + columnNameList, errColumns := rows.Columns() + if errColumns != nil { + return errColumns + } + + //从结构体反射出来的属性名 + fieldNameMap := getFieldNameMap(destValue, destType) + + for rows.Next() { + scans := getScans(columnNameList, fieldNameMap, destValue) + err := rows.Scan(scans...) + if err != nil { + return err + } + } + + return nil +} + +// RawSql 执行原始的sql语句 +func (ex *Builder) RawSql(sql string, paramList ...interface{}) *Builder { + ex.sql = sql + ex.paramList = paramList + return ex +} + +func (ex *Builder) GetSqlAndParams() (string, []interface{}) { + if ex.sql != "" { + return ex.sql, ex.paramList + } + + var paramList []interface{} + + fieldStr, paramList := handleField(ex.selectList, ex.selectExpList, paramList) + whereStr, paramList := ex.handleWhere(ex.whereList, paramList) + joinStr := handleJoin(ex.joinList) + groupStr := handleGroup(ex.groupList) + havingStr, paramList := ex.handleHaving(ex.havingList, paramList) + orderStr := handleOrder(ex.orderList) + limitStr, paramList := handleLimit(ex.offset, ex.pageSize, paramList) + lockStr := handleLockForUpdate(ex.isLockForUpdate) + + sqlStr := "SELECT " + fieldStr + " FROM " + ex.tableName + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr + + if ex.isDebug { + fmt.Println(sqlStr) + fmt.Println(paramList...) + } + + return sqlStr, paramList +} + +// Update 更新记录 +func (ex *Builder) Update(dest interface{}) (int64, error) { + var paramList []any + setStr, paramList := ex.handleSet(dest, paramList) + whereStr, paramList := ex.handleWhere(ex.whereList, paramList) + sqlStr := "UPDATE " + ex.tableName + setStr + whereStr + + return ex.ExecAffected(sqlStr, paramList...) +} + +// Delete 删除记录 +func (ex *Builder) Delete() (int64, error) { + var paramList []any + whereStr, paramList := ex.handleWhere(ex.whereList, paramList) + sqlStr := "DELETE FROM " + ex.tableName + whereStr + + return ex.ExecAffected(sqlStr, paramList...) +} + +// Truncate 清空记录, sqlte3不支持此操作 +func (ex *Builder) Truncate() (int64, error) { + sqlStr := "TRUNCATE TABLE " + ex.tableName + + return ex.ExecAffected(sqlStr) +} + +// Value 字段值 +func (ex *Builder) Value(fieldName string, dest interface{}) error { + ex.Select(fieldName).Limit(0, 1) + + rows, errRows := ex.GetRows() + defer rows.Close() + if errRows != nil { + return errRows + } + + destValue := reflect.ValueOf(dest).Elem() + + //从数据库中读出来的字段名字 + columnNameList, errColumns := rows.Columns() + if errColumns != nil { + return errColumns + } + + for rows.Next() { + var scans []interface{} + for _, columnName := range columnNameList { + if fieldName == columnName { + scans = append(scans, destValue.Addr().Interface()) + } else { + var emptyVal interface{} + scans = append(scans, &emptyVal) + } + } + + err := rows.Scan(scans...) + if err != nil { + return err + } + } + + return nil +} + +// Pluck 获取某一列的值 +func (ex *Builder) Pluck(fieldName string, values interface{}) error { + ex.Select(fieldName) + + rows, errRows := ex.GetRows() + defer rows.Close() + if errRows != nil { + return errRows + } + + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + destValue := reflect.New(destType).Elem() + + //从数据库中读出来的字段名字 + columnNameList, errColumns := rows.Columns() + if errColumns != nil { + return errColumns + } + + for rows.Next() { + var scans []interface{} + for _, columnName := range columnNameList { + if fieldName == columnName { + scans = append(scans, destValue.Addr().Interface()) + } else { + var emptyVal interface{} + scans = append(scans, &emptyVal) + } + } + + err := rows.Scan(scans...) + if err != nil { + return err + } + + destSlice.Set(reflect.Append(destSlice, destValue)) + } + + return nil +} + +// Increment 某字段自增 +func (ex *Builder) Increment(fieldName string, step int) (int64, error) { + var paramList []any + paramList = append(paramList, step) + whereStr, paramList := ex.handleWhere(ex.whereList, paramList) + sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "+?" + whereStr + + return ex.ExecAffected(sqlStr, paramList...) +} + +// Decrement 某字段自减 +func (ex *Builder) Decrement(fieldName string, step int) (int64, error) { + var paramList []any + paramList = append(paramList, step) + whereStr, paramList := ex.handleWhere(ex.whereList, paramList) + sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "-?" + whereStr + + return ex.ExecAffected(sqlStr, paramList...) +} + +// Exec 通用执行-新增,更新,删除 +func (ex *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { + if ex.isDebug { + fmt.Println(sqlStr) + fmt.Println(args...) + } + + smt, err1 := ex.LinkCommon.Prepare(sqlStr) + if err1 != nil { + return nil, err1 + } + defer smt.Close() + + res, err2 := smt.Exec(args...) + if err2 != nil { + return nil, err2 + } + + //ex.clear() + return res, nil +} + +// ExecAffected 通用执行-更新,删除 +func (ex *Builder) ExecAffected(sqlStr string, args ...interface{}) (int64, error) { + res, err := ex.Exec(sqlStr, args...) + if err != nil { + return 0, err + } + + count, err := res.RowsAffected() + if err != nil { + return 0, err + } + + return count, nil +} + +// Debug 链式操作-是否开启调试,打印sql +func (ex *Builder) Debug(isDebug bool) *Builder { + ex.isDebug = isDebug + return ex +} + +// Table 链式操作-从哪个表查询,允许直接写别名,例如 person p +func (ex *Builder) Table(tableName string) *Builder { + ex.tableName = tableName + return ex +} + +// GroupBy 链式操作,以某字段进行分组 +func (ex *Builder) GroupBy(fieldName string) *Builder { + ex.groupList = append(ex.groupList, fieldName) + return ex +} + +// OrderBy 链式操作,以某字段进行排序 +func (ex *Builder) OrderBy(field string, orderType string) *Builder { + ex.orderList = append(ex.orderList, field+" "+orderType) + return ex +} + +// Limit 链式操作,分页 +func (ex *Builder) Limit(offset int, pageSize int) *Builder { + ex.offset = offset + ex.pageSize = pageSize + return ex +} + +// Page 链式操作,分页 +func (ex *Builder) Page(pageNum int, pageSize int) *Builder { + ex.offset = (pageNum - 1) * pageSize + ex.pageSize = pageSize + return ex +} + +// LockForUpdate 加锁, sqlte3不支持此操作 +func (ex *Builder) LockForUpdate(isLockForUpdate bool) *Builder { + ex.isLockForUpdate = isLockForUpdate + return ex +} + +//拼接SQL,查询与筛选通用操作 +func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) { + var whereList []string + for i := 0; i < len(where); i++ { + if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() { + executor := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer())) + subSql, subParams := executor.GetSqlAndParams() + + if where[i].Opt != Raw { + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"("+subSql+")") + paramList = append(paramList, subParams...) + } else { + + } + } else { + if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le { + if ex.driverName == "sqlite3" { + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?") + } else { + switch where[i].Val.(type) { + case float32: + whereList = append(whereList, ex.getConcat(where[i].Field, "''")+" "+where[i].Opt+" "+"?") + case float64: + whereList = append(whereList, ex.getConcat(where[i].Field, "''")+" "+where[i].Opt+" "+"?") + default: + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?") + } + } + + paramList = append(paramList, fmt.Sprintf("%v", where[i].Val)) + } + + if where[i].Opt == Between || where[i].Opt == NotBetween { + values := toAnyArr(where[i].Val) + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"(?) AND (?)") + paramList = append(paramList, values...) + } + + if where[i].Opt == Like || where[i].Opt == NotLike { + values := toAnyArr(where[i].Val) + var valueStr []string + for j := 0; j < len(values); j++ { + str := fmt.Sprintf("%v", values[j]) + + if "%" != str { + paramList = append(paramList, str) + valueStr = append(valueStr, "?") + } else { + valueStr = append(valueStr, "'"+str+"'") + } + } + + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+ex.getConcat(valueStr...)) + } + + if where[i].Opt == In || where[i].Opt == NotIn { + values := toAnyArr(where[i].Val) + var placeholder []string + for j := 0; j < len(values); j++ { + placeholder = append(placeholder, "?") + } + + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"("+strings.Join(placeholder, ",")+")") + paramList = append(paramList, values...) + } + + if where[i].Opt == Raw { + whereList = append(whereList, where[i].Field+fmt.Sprintf("%v", where[i].Val)) + } + } + } + + return whereList, paramList +} + +//将一个interface抽取成数组 +func toAnyArr(val any) []any { + var values []any + switch val.(type) { + case []int: + for _, value := range val.([]int) { + values = append(values, value) + } + case []int64: + for _, value := range val.([]int64) { + values = append(values, value) + } + case []float32: + for _, value := range val.([]float32) { + values = append(values, value) + } + case []float64: + for _, value := range val.([]float64) { + values = append(values, value) + } + case []string: + for _, value := range val.([]string) { + values = append(values, value) + } + } + + return values +} + +//反射表名,优先从方法获取,没有方法则从名字获取 +func getTableName(typeOf reflect.Type, valueOf reflect.Value) string { + method, isSet := typeOf.MethodByName("TableName") + if isSet { + var paramList []reflect.Value + paramList = append(paramList, valueOf) + res := method.Func.Call(paramList) + return res[0].String() + } else { + arr := strings.Split(typeOf.String(), ".") + return helper.UnderLine(arr[len(arr)-1]) + } +} + +func getFieldNameMap(destValue reflect.Value, destType reflect.Type) map[string]int { + fieldNameMap := make(map[string]int) + for i := 0; i < destValue.NumField(); i++ { + fieldNameMap[destType.Field(i).Name] = i + } + + return fieldNameMap +} + +func getScans(columnNameList []string, fieldNameMap map[string]int, destValue reflect.Value) []interface{} { + var scans []interface{} + for _, columnName := range columnNameList { + fieldName := helper.CamelString(strings.ToLower(columnName)) + index, ok := fieldNameMap[fieldName] + if ok { + scans = append(scans, destValue.Field(index).Addr().Interface()) + } else { + var emptyVal interface{} + scans = append(scans, &emptyVal) + } + } + + return scans +} + +func (ex *Builder) getConcat(vars ...string) string { + if ex.driverName == "sqlite3" { + return strings.Join(vars, "||") + } else { + return "CONCAT(" + strings.Join(vars, ",") + ")" + } +} diff --git a/builder/handle.go b/builder/handle.go new file mode 100644 index 0000000..790d300 --- /dev/null +++ b/builder/handle.go @@ -0,0 +1,119 @@ +package builder + +import ( + "github.com/tangpanqing/aorm/helper" + "reflect" + "strings" +) + +//拼接SQL,字段相关 +func handleField(selectList []string, selectExpList []*SelectItem, paramList []any) (string, []any) { + if len(selectList) == 0 && len(selectExpList) == 0 { + return "*", paramList + } + + //处理子语句 + for i := 0; i < len(selectExpList); i++ { + executor := *(selectExpList[i].Executor) + subSql, subParamList := executor.GetSqlAndParams() + selectList = append(selectList, "("+subSql+") AS "+selectExpList[i].FieldName) + paramList = append(paramList, subParamList...) + } + + return strings.Join(selectList, ","), paramList +} + +//拼接SQL,查询条件 +func (ex *Builder) handleWhere(where []WhereItem, paramList []any) (string, []any) { + if len(where) == 0 { + return "", paramList + } + + whereList, paramList := ex.whereAndHaving(where, paramList) + + return " WHERE " + strings.Join(whereList, " AND "), paramList +} + +//拼接SQL,更新信息 +func (ex *Builder) handleSet(dest interface{}, paramList []any) (string, []any) { + typeOf := reflect.TypeOf(dest) + valueOf := reflect.ValueOf(dest) + + //如果没有设置表名 + if ex.tableName == "" { + ex.tableName = getTableName(typeOf, valueOf) + } + + var keys []string + for i := 0; i < typeOf.Elem().NumField(); i++ { + isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() + if isNotNull { + key := helper.UnderLine(typeOf.Elem().Field(i).Name) + val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() + + keys = append(keys, key+"=?") + paramList = append(paramList, val) + } + } + + return " SET " + strings.Join(keys, ","), paramList +} + +//拼接SQL,关联查询 +func handleJoin(joinList []string) string { + if len(joinList) == 0 { + return "" + } + + return " " + strings.Join(joinList, " ") +} + +//拼接SQL,结果分组 +func handleGroup(groupList []string) string { + if len(groupList) == 0 { + return "" + } + + return " GROUP BY " + strings.Join(groupList, ",") +} + +//拼接SQL,结果筛选 +func (ex *Builder) handleHaving(having []WhereItem, paramList []any) (string, []any) { + if len(having) == 0 { + return "", paramList + } + + whereList, paramList := ex.whereAndHaving(having, paramList) + + return " Having " + strings.Join(whereList, " AND "), paramList +} + +//拼接SQL,结果排序 +func handleOrder(orderList []string) string { + if len(orderList) == 0 { + return "" + } + + return " Order BY " + strings.Join(orderList, ",") +} + +//拼接SQL,分页相关 +func handleLimit(offset int, pageSize int, paramList []any) (string, []any) { + if 0 == pageSize { + return "", paramList + } + + paramList = append(paramList, offset) + paramList = append(paramList, pageSize) + + return " Limit ?,? ", paramList +} + +//拼接SQL,锁 +func handleLockForUpdate(isLock bool) string { + if isLock { + return " FOR UPDATE" + } + + return "" +} diff --git a/builder/having.go b/builder/having.go new file mode 100644 index 0000000..c38755a --- /dev/null +++ b/builder/having.go @@ -0,0 +1,151 @@ +package builder + +import ( + "github.com/tangpanqing/aorm/helper" + "reflect" +) + +// Having 链式操作,以对象作为筛选条件 +func (ex *Builder) Having(dest interface{}) *Builder { + typeOf := reflect.TypeOf(dest) + valueOf := reflect.ValueOf(dest) + + //如果没有设置表名 + if ex.tableName == "" { + ex.tableName = getTableName(typeOf, valueOf) + } + + for i := 0; i < typeOf.Elem().NumField(); i++ { + isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() + if isNotNull { + key := helper.UnderLine(typeOf.Elem().Field(i).Name) + val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() + ex.havingList = append(ex.havingList, WhereItem{Field: key, Opt: Eq, Val: val}) + } + } + + return ex +} + +// HavingArr 链式操作,以数组作为筛选条件 +func (ex *Builder) HavingArr(havingList []WhereItem) *Builder { + ex.havingList = append(ex.havingList, havingList...) + return ex +} + +func (ex *Builder) HavingEq(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Eq, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingNe(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Ne, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingGt(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Gt, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingGe(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Ge, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingLt(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Lt, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingLe(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Le, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingIn(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: In, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingNotIn(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: NotIn, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingBetween(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Between, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingNotBetween(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: NotBetween, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingLike(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Like, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingNotLike(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: NotLike, + Val: val, + }) + return ex +} + +func (ex *Builder) HavingRaw(field string, val interface{}) *Builder { + ex.havingList = append(ex.havingList, WhereItem{ + Field: field, + Opt: Raw, + Val: val, + }) + return ex +} diff --git a/builder/join.go b/builder/join.go new file mode 100644 index 0000000..e504925 --- /dev/null +++ b/builder/join.go @@ -0,0 +1,19 @@ +package builder + +// LeftJoin 链式操作,左联查询,例如 LeftJoin("project p", "p.project_id=o.project_id") +func (ex *Builder) LeftJoin(tableName string, condition string) *Builder { + ex.joinList = append(ex.joinList, "LEFT JOIN "+tableName+" ON "+condition) + return ex +} + +// RightJoin 链式操作,右联查询,例如 RightJoin("project p", "p.project_id=o.project_id") +func (ex *Builder) RightJoin(tableName string, condition string) *Builder { + ex.joinList = append(ex.joinList, "RIGHT JOIN "+tableName+" ON "+condition) + return ex +} + +// Join 链式操作,内联查询,例如 Join("project p", "p.project_id=o.project_id") +func (ex *Builder) Join(tableName string, condition string) *Builder { + ex.joinList = append(ex.joinList, "INNER JOIN "+tableName+" ON "+condition) + return ex +} diff --git a/builder/select.go b/builder/select.go new file mode 100644 index 0000000..6e1018c --- /dev/null +++ b/builder/select.go @@ -0,0 +1,46 @@ +package builder + +// Select 链式操作-查询哪些字段,默认 * +func (ex *Builder) Select(fields ...string) *Builder { + ex.selectList = append(ex.selectList, fields...) + return ex +} + +// SelectCount 链式操作-count(field) as field_new +func (ex *Builder) SelectCount(field string, fieldNew string) *Builder { + ex.selectList = append(ex.selectList, "count("+field+") AS "+fieldNew) + return ex +} + +// SelectSum 链式操作-sum(field) as field_new +func (ex *Builder) SelectSum(field string, fieldNew string) *Builder { + ex.selectList = append(ex.selectList, "sum("+field+") AS "+fieldNew) + return ex +} + +// SelectMin 链式操作-min(field) as field_new +func (ex *Builder) SelectMin(field string, fieldNew string) *Builder { + ex.selectList = append(ex.selectList, "min("+field+") AS "+fieldNew) + return ex +} + +// SelectMax 链式操作-max(field) as field_new +func (ex *Builder) SelectMax(field string, fieldNew string) *Builder { + ex.selectList = append(ex.selectList, "max("+field+") AS "+fieldNew) + return ex +} + +// SelectAvg 链式操作-avg(field) as field_new +func (ex *Builder) SelectAvg(field string, fieldNew string) *Builder { + ex.selectList = append(ex.selectList, "avg("+field+") AS "+fieldNew) + return ex +} + +// SelectExp 链式操作-表达式 +func (ex *Builder) SelectExp(dbSub **Builder, fieldName string) *Builder { + ex.selectExpList = append(ex.selectExpList, &SelectItem{ + Executor: dbSub, + FieldName: fieldName, + }) + return ex +} diff --git a/builder/where.go b/builder/where.go new file mode 100644 index 0000000..5b147d6 --- /dev/null +++ b/builder/where.go @@ -0,0 +1,151 @@ +package builder + +import ( + "github.com/tangpanqing/aorm/helper" + "reflect" +) + +// Where 链式操作,以对象作为查询条件 +func (ex *Builder) Where(dest interface{}) *Builder { + typeOf := reflect.TypeOf(dest) + valueOf := reflect.ValueOf(dest) + + //如果没有设置表名 + if ex.tableName == "" { + ex.tableName = getTableName(typeOf, valueOf) + } + + for i := 0; i < typeOf.Elem().NumField(); i++ { + isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() + if isNotNull { + key := helper.UnderLine(typeOf.Elem().Field(i).Name) + val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() + ex.whereList = append(ex.whereList, WhereItem{Field: key, Opt: Eq, Val: val}) + } + } + + return ex +} + +// WhereArr 链式操作,以数组作为查询条件 +func (ex *Builder) WhereArr(whereList []WhereItem) *Builder { + ex.whereList = append(ex.whereList, whereList...) + return ex +} + +func (ex *Builder) WhereEq(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Eq, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereNe(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Ne, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereGt(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Gt, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereGe(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Ge, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereLt(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Lt, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereLe(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Le, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereIn(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: In, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereNotIn(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: NotIn, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereBetween(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Between, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereNotBetween(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: NotBetween, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereLike(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Like, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereNotLike(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: NotLike, + Val: val, + }) + return ex +} + +func (ex *Builder) WhereRaw(field string, val interface{}) *Builder { + ex.whereList = append(ex.whereList, WhereItem{ + Field: field, + Opt: Raw, + Val: val, + }) + return ex +} diff --git a/executor/crud.go b/executor/crud.go deleted file mode 100644 index d6c58ba..0000000 --- a/executor/crud.go +++ /dev/null @@ -1,1124 +0,0 @@ -package executor - -import ( - "database/sql" - "errors" - "fmt" - "github.com/tangpanqing/aorm/helper" - "github.com/tangpanqing/aorm/null" - "reflect" - "strings" - "unsafe" -) - -const Desc = "DESC" -const Asc = "ASC" - -const Eq = "=" -const Ne = "!=" -const Gt = ">" -const Ge = ">=" -const Lt = "<" -const Le = "<=" - -const In = "IN" -const NotIn = "NOT IN" -const Like = "LIKE" -const NotLike = "NOT LIKE" -const Between = "BETWEEN" -const NotBetween = "NOT BETWEEN" - -const Raw = "Raw" - -type WhereItem struct { - Field string - Opt string - Val any -} - -type IntStruct struct { - C null.Int -} - -type FloatStruct struct { - C null.Float -} - -// Insert 增加记录 -func (ex *Executor) Insert(dest interface{}) (int64, error) { - typeOf := reflect.TypeOf(dest) - valueOf := reflect.ValueOf(dest) - - //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = reflectTableName(typeOf, valueOf) - } - - var keys []string - var paramList []any - var place []string - for i := 0; i < typeOf.Elem().NumField(); i++ { - isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() - if isNotNull { - key := helper.UnderLine(typeOf.Elem().Field(i).Name) - val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() - keys = append(keys, key) - paramList = append(paramList, val) - place = append(place, "?") - } - } - - sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" - - res, err := ex.Exec(sqlStr, paramList...) - if err != nil { - return 0, err - } - - lastId, err := res.LastInsertId() - if err != nil { - return 0, err - } - - return lastId, nil -} - -// InsertBatch 批量增加记录 -func (ex *Executor) InsertBatch(values interface{}) (int64, error) { - - var keys []string - var paramList []any - var place []string - - valueOf := reflect.ValueOf(values).Elem() - if valueOf.Len() == 0 { - return 0, errors.New("the data list for insert batch not found") - } - typeOf := reflect.TypeOf(values).Elem().Elem() - - //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = reflectTableName(typeOf, valueOf.Index(0)) - } - - for j := 0; j < valueOf.Len(); j++ { - var placeItem []string - - for i := 0; i < valueOf.Index(j).NumField(); i++ { - isNotNull := valueOf.Index(j).Field(i).Field(0).Field(1).Bool() - if isNotNull { - if j == 0 { - key := helper.UnderLine(typeOf.Field(i).Name) - keys = append(keys, key) - } - - val := valueOf.Index(j).Field(i).Field(0).Field(0).Interface() - paramList = append(paramList, val) - placeItem = append(placeItem, "?") - } - } - - place = append(place, "("+strings.Join(placeItem, ",")+")") - } - - sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") - - res, err := ex.Exec(sqlStr, paramList...) - if err != nil { - return 0, err - } - - count, err := res.RowsAffected() - if err != nil { - return 0, err - } - - return count, nil -} - -// GetRows 获取行操作 -func (ex *Executor) GetRows() (*sql.Rows, error) { - sqlStr, paramList := ex.GetSqlAndParams() - - smt, errSmt := ex.LinkCommon.Prepare(sqlStr) - if errSmt != nil { - return nil, errSmt - } - //defer smt.Close() - - rows, errRows := smt.Query(paramList...) - if errRows != nil { - return nil, errRows - } - - return rows, nil -} - -// GetMany 查询记录(新) -func (ex *Executor) GetMany(values interface{}) error { - rows, errRows := ex.GetRows() - defer rows.Close() - if errRows != nil { - return errRows - } - - destSlice := reflect.Indirect(reflect.ValueOf(values)) - destType := destSlice.Type().Elem() - destValue := reflect.New(destType).Elem() - - //从数据库中读出来的字段名字 - columnNameList, errColumns := rows.Columns() - if errColumns != nil { - return errColumns - } - - //从结构体反射出来的属性名 - fieldNameMap := getFieldNameMap(destValue, destType) - - for rows.Next() { - scans := getScans(columnNameList, fieldNameMap, destValue) - - errScan := rows.Scan(scans...) - if errScan != nil { - return errScan - } - - destSlice.Set(reflect.Append(destSlice, destValue)) - } - - return nil -} - -// GetOne 查询某一条记录 -func (ex *Executor) GetOne(obj interface{}) error { - ex.Limit(0, 1) - - rows, errRows := ex.GetRows() - defer rows.Close() - if errRows != nil { - return errRows - } - - destType := reflect.TypeOf(obj).Elem() - destValue := reflect.ValueOf(obj).Elem() - - //从数据库中读出来的字段名字 - columnNameList, errColumns := rows.Columns() - if errColumns != nil { - return errColumns - } - - //从结构体反射出来的属性名 - fieldNameMap := getFieldNameMap(destValue, destType) - - for rows.Next() { - scans := getScans(columnNameList, fieldNameMap, destValue) - err := rows.Scan(scans...) - if err != nil { - return err - } - } - - return nil -} - -// RawSql 执行原始的sql语句 -func (ex *Executor) RawSql(sql string, paramList ...interface{}) *Executor { - ex.sql = sql - ex.paramList = paramList - return ex -} - -func (ex *Executor) GetSqlAndParams() (string, []interface{}) { - if ex.sql != "" { - return ex.sql, ex.paramList - } - - var paramList []interface{} - - fieldStr, paramList := handleField(ex.selectList, ex.selectExpList, paramList) - whereStr, paramList := handleWhere(ex.whereList, paramList) - joinStr := handleJoin(ex.joinList) - groupStr := handleGroup(ex.groupList) - havingStr, paramList := handleHaving(ex.havingList, paramList) - orderStr := handleOrder(ex.orderList) - limitStr, paramList := handleLimit(ex.offset, ex.pageSize, paramList) - lockStr := handleLockForUpdate(ex.isLockForUpdate) - - sqlStr := "SELECT " + fieldStr + " FROM " + ex.tableName + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr - - if ex.isDebug { - fmt.Println(sqlStr) - fmt.Println(paramList...) - } - - return sqlStr, paramList -} - -// Update 更新记录 -func (ex *Executor) Update(dest interface{}) (int64, error) { - var paramList []any - setStr, paramList := ex.handleSet(dest, paramList) - whereStr, paramList := handleWhere(ex.whereList, paramList) - sqlStr := "UPDATE " + ex.tableName + setStr + whereStr - - return ex.ExecAffected(sqlStr, paramList...) -} - -// Delete 删除记录 -func (ex *Executor) Delete() (int64, error) { - var paramList []any - whereStr, paramList := handleWhere(ex.whereList, paramList) - sqlStr := "DELETE FROM " + ex.tableName + whereStr - - return ex.ExecAffected(sqlStr, paramList...) -} - -// Truncate 清空记录 -func (ex *Executor) Truncate() (int64, error) { - sqlStr := "TRUNCATE TABLE " + ex.tableName - - return ex.ExecAffected(sqlStr) -} - -// Count 聚合函数-数量 -func (ex *Executor) Count(fieldName string) (int64, error) { - var obj []IntStruct - err := ex.Select("count(" + fieldName + ") as c").GetMany(&obj) - if err != nil { - return 0, err - } - - return obj[0].C.Int64, nil -} - -// Sum 聚合函数-合计 -func (ex *Executor) Sum(fieldName string) (float64, error) { - var obj []FloatStruct - err := ex.Select("sum(" + fieldName + ") as c").GetMany(&obj) - if err != nil { - return 0, err - } - - return obj[0].C.Float64, nil -} - -// Avg 聚合函数-平均值 -func (ex *Executor) Avg(fieldName string) (float64, error) { - var obj []FloatStruct - err := ex.Select("avg(" + fieldName + ") as c").GetMany(&obj) - if err != nil { - return 0, err - } - - return obj[0].C.Float64, nil -} - -// Max 聚合函数-最大值 -func (ex *Executor) Max(fieldName string) (float64, error) { - var obj []FloatStruct - err := ex.Select("max(" + fieldName + ") as c").GetMany(&obj) - if err != nil { - return 0, err - } - - return obj[0].C.Float64, nil -} - -// Min 聚合函数-最小值 -func (ex *Executor) Min(fieldName string) (float64, error) { - var obj []FloatStruct - err := ex.Select("min(" + fieldName + ") as c").GetMany(&obj) - if err != nil { - return 0, err - } - - return obj[0].C.Float64, nil -} - -// Value 字段值 -func (ex *Executor) Value(fieldName string, dest interface{}) error { - ex.Select(fieldName).Limit(0, 1) - - rows, errRows := ex.GetRows() - defer rows.Close() - if errRows != nil { - return errRows - } - - destValue := reflect.ValueOf(dest).Elem() - - //从数据库中读出来的字段名字 - columnNameList, errColumns := rows.Columns() - if errColumns != nil { - return errColumns - } - - for rows.Next() { - var scans []interface{} - for _, columnName := range columnNameList { - if fieldName == columnName { - scans = append(scans, destValue.Addr().Interface()) - } else { - var emptyVal interface{} - scans = append(scans, &emptyVal) - } - } - - err := rows.Scan(scans...) - if err != nil { - return err - } - } - - return nil -} - -// Pluck 获取某一列的值 -func (ex *Executor) Pluck(fieldName string, values interface{}) error { - ex.Select(fieldName) - - rows, errRows := ex.GetRows() - defer rows.Close() - if errRows != nil { - return errRows - } - - destSlice := reflect.Indirect(reflect.ValueOf(values)) - destType := destSlice.Type().Elem() - destValue := reflect.New(destType).Elem() - - //从数据库中读出来的字段名字 - columnNameList, errColumns := rows.Columns() - if errColumns != nil { - return errColumns - } - - for rows.Next() { - var scans []interface{} - for _, columnName := range columnNameList { - if fieldName == columnName { - scans = append(scans, destValue.Addr().Interface()) - } else { - var emptyVal interface{} - scans = append(scans, &emptyVal) - } - } - - err := rows.Scan(scans...) - if err != nil { - return err - } - - destSlice.Set(reflect.Append(destSlice, destValue)) - } - - return nil -} - -// Increment 某字段自增 -func (ex *Executor) Increment(fieldName string, step int) (int64, error) { - var paramList []any - paramList = append(paramList, step) - whereStr, paramList := handleWhere(ex.whereList, paramList) - sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "+?" + whereStr - - return ex.ExecAffected(sqlStr, paramList...) -} - -// Decrement 某字段自减 -func (ex *Executor) Decrement(fieldName string, step int) (int64, error) { - var paramList []any - paramList = append(paramList, step) - whereStr, paramList := handleWhere(ex.whereList, paramList) - sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "-?" + whereStr - - return ex.ExecAffected(sqlStr, paramList...) -} - -// Exec 通用执行-新增,更新,删除 -func (ex *Executor) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { - if ex.isDebug { - fmt.Println(sqlStr) - fmt.Println(args...) - } - - smt, err1 := ex.LinkCommon.Prepare(sqlStr) - if err1 != nil { - return nil, err1 - } - defer smt.Close() - - res, err2 := smt.Exec(args...) - if err2 != nil { - return nil, err2 - } - - //ex.clear() - return res, nil -} - -// ExecAffected 通用执行-更新,删除 -func (ex *Executor) ExecAffected(sqlStr string, args ...interface{}) (int64, error) { - res, err := ex.Exec(sqlStr, args...) - if err != nil { - return 0, err - } - - count, err := res.RowsAffected() - if err != nil { - return 0, err - } - - return count, nil -} - -// Debug 链式操作-是否开启调试,打印sql -func (ex *Executor) Debug(isDebug bool) *Executor { - ex.isDebug = isDebug - return ex -} - -// Select 链式操作-查询哪些字段,默认 * -func (ex *Executor) Select(fields ...string) *Executor { - ex.selectList = append(ex.selectList, fields...) - return ex -} - -// SelectCount 链式操作-count(field) as field_new -func (ex *Executor) SelectCount(field string, fieldNew string) *Executor { - ex.selectList = append(ex.selectList, "count("+field+") AS "+fieldNew) - return ex -} - -// SelectSum 链式操作-sum(field) as field_new -func (ex *Executor) SelectSum(field string, fieldNew string) *Executor { - ex.selectList = append(ex.selectList, "sum("+field+") AS "+fieldNew) - return ex -} - -// SelectMin 链式操作-min(field) as field_new -func (ex *Executor) SelectMin(field string, fieldNew string) *Executor { - ex.selectList = append(ex.selectList, "min("+field+") AS "+fieldNew) - return ex -} - -// SelectMax 链式操作-max(field) as field_new -func (ex *Executor) SelectMax(field string, fieldNew string) *Executor { - ex.selectList = append(ex.selectList, "max("+field+") AS "+fieldNew) - return ex -} - -// SelectAvg 链式操作-avg(field) as field_new -func (ex *Executor) SelectAvg(field string, fieldNew string) *Executor { - ex.selectList = append(ex.selectList, "avg("+field+") AS "+fieldNew) - return ex -} - -// SelectExp 链式操作-表达式 -func (ex *Executor) SelectExp(dbSub **Executor, fieldName string) *Executor { - ex.selectExpList = append(ex.selectExpList, &ExpItem{ - Executor: dbSub, - FieldName: fieldName, - }) - return ex -} - -// Table 链式操作-从哪个表查询,允许直接写别名,例如 person p -func (ex *Executor) Table(tableName string) *Executor { - ex.tableName = tableName - return ex -} - -// LeftJoin 链式操作,左联查询,例如 LeftJoin("project p", "p.project_id=o.project_id") -func (ex *Executor) LeftJoin(tableName string, condition string) *Executor { - ex.joinList = append(ex.joinList, "LEFT JOIN "+tableName+" ON "+condition) - return ex -} - -// RightJoin 链式操作,右联查询,例如 RightJoin("project p", "p.project_id=o.project_id") -func (ex *Executor) RightJoin(tableName string, condition string) *Executor { - ex.joinList = append(ex.joinList, "RIGHT JOIN "+tableName+" ON "+condition) - return ex -} - -// Join 链式操作,内联查询,例如 Join("project p", "p.project_id=o.project_id") -func (ex *Executor) Join(tableName string, condition string) *Executor { - ex.joinList = append(ex.joinList, "INNER JOIN "+tableName+" ON "+condition) - return ex -} - -// Where 链式操作,以对象作为查询条件 -func (ex *Executor) Where(dest interface{}) *Executor { - typeOf := reflect.TypeOf(dest) - valueOf := reflect.ValueOf(dest) - - //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = reflectTableName(typeOf, valueOf) - } - - for i := 0; i < typeOf.Elem().NumField(); i++ { - isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() - if isNotNull { - key := helper.UnderLine(typeOf.Elem().Field(i).Name) - val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() - ex.whereList = append(ex.whereList, WhereItem{Field: key, Opt: Eq, Val: val}) - } - } - - return ex -} - -// WhereArr 链式操作,以数组作为查询条件 -func (ex *Executor) WhereArr(whereList []WhereItem) *Executor { - ex.whereList = append(ex.whereList, whereList...) - return ex -} - -func (ex *Executor) WhereEq(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Eq, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereNe(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Ne, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereGt(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Gt, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereGe(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Ge, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereLt(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Lt, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereLe(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Le, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereIn(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: In, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereNotIn(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: NotIn, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereBetween(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Between, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereNotBetween(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: NotBetween, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereLike(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Like, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereNotLike(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: NotLike, - Val: val, - }) - return ex -} - -func (ex *Executor) WhereRaw(field string, val interface{}) *Executor { - ex.whereList = append(ex.whereList, WhereItem{ - Field: field, - Opt: Raw, - Val: val, - }) - return ex -} - -// GroupBy 链式操作,以某字段进行分组 -func (ex *Executor) GroupBy(fieldName string) *Executor { - ex.groupList = append(ex.groupList, fieldName) - return ex -} - -// Having 链式操作,以对象作为筛选条件 -func (ex *Executor) Having(dest interface{}) *Executor { - typeOf := reflect.TypeOf(dest) - valueOf := reflect.ValueOf(dest) - - //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = reflectTableName(typeOf, valueOf) - } - - for i := 0; i < typeOf.Elem().NumField(); i++ { - isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() - if isNotNull { - key := helper.UnderLine(typeOf.Elem().Field(i).Name) - val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() - ex.havingList = append(ex.havingList, WhereItem{Field: key, Opt: Eq, Val: val}) - } - } - - return ex -} - -// HavingArr 链式操作,以数组作为筛选条件 -func (ex *Executor) HavingArr(havingList []WhereItem) *Executor { - ex.havingList = append(ex.havingList, havingList...) - return ex -} - -func (ex *Executor) HavingEq(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Eq, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingNe(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Ne, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingGt(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Gt, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingGe(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Ge, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingLt(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Lt, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingLe(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Le, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingIn(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: In, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingNotIn(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: NotIn, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingBetween(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Between, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingNotBetween(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: NotBetween, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingLike(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Like, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingNotLike(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: NotLike, - Val: val, - }) - return ex -} - -func (ex *Executor) HavingRaw(field string, val interface{}) *Executor { - ex.havingList = append(ex.havingList, WhereItem{ - Field: field, - Opt: Raw, - Val: val, - }) - return ex -} - -// OrderBy 链式操作,以某字段进行排序 -func (ex *Executor) OrderBy(field string, orderType string) *Executor { - ex.orderList = append(ex.orderList, field+" "+orderType) - return ex -} - -// Limit 链式操作,分页 -func (ex *Executor) Limit(offset int, pageSize int) *Executor { - ex.offset = offset - ex.pageSize = pageSize - return ex -} - -// Page 链式操作,分页 -func (ex *Executor) Page(pageNum int, pageSize int) *Executor { - ex.offset = (pageNum - 1) * pageSize - ex.pageSize = pageSize - return ex -} - -// LockForUpdate 加锁 -func (ex *Executor) LockForUpdate(isLockForUpdate bool) *Executor { - ex.isLockForUpdate = isLockForUpdate - return ex -} - -//拼接SQL,字段相关 -func handleField(selectList []string, selectExpList []*ExpItem, paramList []any) (string, []any) { - if len(selectList) == 0 && len(selectExpList) == 0 { - return "*", paramList - } - - //处理子语句 - for i := 0; i < len(selectExpList); i++ { - executor := *(selectExpList[i].Executor) - subSql, subParamList := executor.GetSqlAndParams() - selectList = append(selectList, "("+subSql+") AS "+selectExpList[i].FieldName) - paramList = append(paramList, subParamList...) - } - - return strings.Join(selectList, ","), paramList -} - -//拼接SQL,查询条件 -func handleWhere(where []WhereItem, paramList []any) (string, []any) { - if len(where) == 0 { - return "", paramList - } - - whereList, paramList := whereAndHaving(where, paramList) - - return " WHERE " + strings.Join(whereList, " AND "), paramList -} - -//拼接SQL,更新信息 -func (ex *Executor) handleSet(dest interface{}, paramList []any) (string, []any) { - typeOf := reflect.TypeOf(dest) - valueOf := reflect.ValueOf(dest) - - //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = reflectTableName(typeOf, valueOf) - } - - var keys []string - for i := 0; i < typeOf.Elem().NumField(); i++ { - isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() - if isNotNull { - key := helper.UnderLine(typeOf.Elem().Field(i).Name) - val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() - - keys = append(keys, key+"=?") - paramList = append(paramList, val) - } - } - - return " SET " + strings.Join(keys, ","), paramList -} - -//拼接SQL,关联查询 -func handleJoin(joinList []string) string { - if len(joinList) == 0 { - return "" - } - - return " " + strings.Join(joinList, " ") -} - -//拼接SQL,结果分组 -func handleGroup(groupList []string) string { - if len(groupList) == 0 { - return "" - } - - return " GROUP BY " + strings.Join(groupList, ",") -} - -//拼接SQL,结果筛选 -func handleHaving(having []WhereItem, paramList []any) (string, []any) { - if len(having) == 0 { - return "", paramList - } - - whereList, paramList := whereAndHaving(having, paramList) - - return " Having " + strings.Join(whereList, " AND "), paramList -} - -//拼接SQL,结果排序 -func handleOrder(orderList []string) string { - if len(orderList) == 0 { - return "" - } - - return " Order BY " + strings.Join(orderList, ",") -} - -//拼接SQL,分页相关 -func handleLimit(offset int, pageSize int, paramList []any) (string, []any) { - if 0 == pageSize { - return "", paramList - } - - paramList = append(paramList, offset) - paramList = append(paramList, pageSize) - - return " Limit ?,? ", paramList -} - -//拼接SQL,锁 -func handleLockForUpdate(isLock bool) string { - if isLock { - return " FOR UPDATE" - } - - return "" -} - -//拼接SQL,查询与筛选通用操作 -func whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) { - var whereList []string - for i := 0; i < len(where); i++ { - if "**executor.Executor" == reflect.TypeOf(where[i].Val).String() { - executor := *(**Executor)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer())) - subSql, subParams := executor.GetSqlAndParams() - - if where[i].Opt != Raw { - whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"("+subSql+")") - paramList = append(paramList, subParams...) - } else { - - } - } else { - if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le { - //如果是浮点数查询 - switch where[i].Val.(type) { - case float32: - whereList = append(whereList, "CONCAT("+where[i].Field+",'') "+where[i].Opt+" "+"?") - case float64: - whereList = append(whereList, "CONCAT("+where[i].Field+",'') "+where[i].Opt+" "+"?") - default: - whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?") - } - - paramList = append(paramList, fmt.Sprintf("%v", where[i].Val)) - } - - if where[i].Opt == Between || where[i].Opt == NotBetween { - values := toAnyArr(where[i].Val) - whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"(?) AND (?)") - paramList = append(paramList, values...) - } - - if where[i].Opt == Like || where[i].Opt == NotLike { - values := toAnyArr(where[i].Val) - var valueStr []string - for j := 0; j < len(values); j++ { - str := fmt.Sprintf("%v", values[j]) - - if "%" != str { - //values[j] = "?" - paramList = append(paramList, str) - valueStr = append(valueStr, "?") - } else { - valueStr = append(valueStr, "'"+str+"'") - } - } - - whereList = append(whereList, where[i].Field+" "+where[i].Opt+" concat("+strings.Join(valueStr, ",")+")") - } - - if where[i].Opt == In || where[i].Opt == NotIn { - values := toAnyArr(where[i].Val) - var placeholder []string - for j := 0; j < len(values); j++ { - placeholder = append(placeholder, "?") - } - - whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"("+strings.Join(placeholder, ",")+")") - paramList = append(paramList, values...) - } - - if where[i].Opt == Raw { - whereList = append(whereList, where[i].Field+fmt.Sprintf("%v", where[i].Val)) - } - } - } - - return whereList, paramList -} - -//将一个interface抽取成数组 -func toAnyArr(val any) []any { - var values []any - switch val.(type) { - case []int: - for _, value := range val.([]int) { - values = append(values, value) - } - case []int64: - for _, value := range val.([]int64) { - values = append(values, value) - } - case []float32: - for _, value := range val.([]float32) { - values = append(values, value) - } - case []float64: - for _, value := range val.([]float64) { - values = append(values, value) - } - case []string: - for _, value := range val.([]string) { - values = append(values, value) - } - } - - return values -} - -//反射表名,优先从方法获取,没有方法则从名字获取 -func reflectTableName(typeOf reflect.Type, valueOf reflect.Value) string { - method, isSet := typeOf.MethodByName("TableName") - if isSet { - var paramList []reflect.Value - paramList = append(paramList, valueOf) - res := method.Func.Call(paramList) - return res[0].String() - } else { - arr := strings.Split(typeOf.String(), ".") - return helper.UnderLine(arr[len(arr)-1]) - } -} - -func getFieldNameMap(destValue reflect.Value, destType reflect.Type) map[string]int { - fieldNameMap := make(map[string]int) - for i := 0; i < destValue.NumField(); i++ { - fieldNameMap[destType.Field(i).Name] = i - } - - return fieldNameMap -} - -func getScans(columnNameList []string, fieldNameMap map[string]int, destValue reflect.Value) []interface{} { - var scans []interface{} - for _, columnName := range columnNameList { - fieldName := helper.CamelString(strings.ToLower(columnName)) - index, ok := fieldNameMap[fieldName] - if ok { - scans = append(scans, destValue.Field(index).Addr().Interface()) - } else { - var emptyVal interface{} - scans = append(scans, &emptyVal) - } - } - - return scans -} diff --git a/executor/executor.go b/executor/executor.go deleted file mode 100644 index 0162731..0000000 --- a/executor/executor.go +++ /dev/null @@ -1,38 +0,0 @@ -package executor - -import ( - "github.com/tangpanqing/aorm/model" -) - -// ExpItem 将某子语句重命名为某字段 -type ExpItem struct { - Executor **Executor - FieldName string -} - -// Executor 查询记录所需要的条件 -type Executor struct { - //数据库操作连接 - LinkCommon model.LinkCommon - - //查询参数 - tableName string - selectList []string - selectExpList []*ExpItem - groupList []string - whereList []WhereItem - joinList []string - havingList []WhereItem - orderList []string - offset int - pageSize int - isDebug bool - isLockForUpdate bool - - //sql与参数 - sql string - paramList []interface{} - - //驱动名字 - driverName string -} diff --git a/migrate_mysql/migrate.go b/migrate_mysql/migrate.go index 7b9dc38..a7977b1 100644 --- a/migrate_mysql/migrate.go +++ b/migrate_mysql/migrate.go @@ -2,7 +2,7 @@ package migrate_mysql import ( "fmt" - "github.com/tangpanqing/aorm/executor" + "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/null" @@ -42,7 +42,7 @@ type MigrateExecutor struct { OpinionList []model.OpinionItem //执行者 - Ex *executor.Executor + Ex *builder.Builder } //ShowCreateTable 查看创建表的ddl diff --git a/migrate_sqlite3/migrate.go b/migrate_sqlite3/migrate.go index 9ee5e26..1ddc7e7 100644 --- a/migrate_sqlite3/migrate.go +++ b/migrate_sqlite3/migrate.go @@ -2,7 +2,7 @@ package migrate_sqlite3 import ( "fmt" - "github.com/tangpanqing/aorm/executor" + "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/null" @@ -48,7 +48,7 @@ type MigrateExecutor struct { OpinionList []model.OpinionItem //执行者 - Ex *executor.Executor + Ex *builder.Builder } //ShowCreateTable 查看创建表的ddl @@ -203,8 +203,7 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C } func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { - sqlIndex := "select * from sqlite_master where type = 'index' and tbl_name=" + "'" + tableName + "'" - + sqlIndex := "select * from sqlite_master where type = 'index' and name not like '%sqlite_autoindex%' and tbl_name=" + "'" + tableName + "'" var sqliteMasterList []SqliteMaster mm.Ex.RawSql(sqlIndex).GetMany(&sqliteMasterList) @@ -217,10 +216,8 @@ func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { t = 0 } - fmt.Println(sql) compileRegex := regexp.MustCompile("INDEX\\s(.*?)\\son.*?\\((.*?)\\)") matchArr := compileRegex.FindAllStringSubmatch(sql, -1) - fmt.Println(matchArr) indexesFromDb = append(indexesFromDb, Index{ NonUnique: null.IntFrom(int64(t)), @@ -229,6 +226,21 @@ func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { }) } + //查询是否有主键索引 + sql := "select * from sqlite_master where type='table' and tbl_name=" + "'" + tableName + "'" + var sqliteMaster SqliteMaster + mm.Ex.RawSql(sql).GetOne(&sqliteMaster) + + compileRegex := regexp.MustCompile("PRIMARY\\sKEY\\s\\((.*?)\\)") + matchArr2 := compileRegex.FindAllStringSubmatch(sqliteMaster.Sql.String, -1) + if len(matchArr2) > 0 { + indexesFromDb = append(indexesFromDb, Index{ + NonUnique: null.IntFrom(0), + ColumnName: null.StringFrom(matchArr2[0][1]), + KeyName: null.StringFrom("PRIMARY"), + }) + } + return indexesFromDb } @@ -243,23 +255,16 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co isFind = 1 if columnCode.DataType.String != columnDb.DataType.String || - columnCode.Extra.String != columnDb.Extra.String || columnCode.ColumnDefault.String != columnDb.ColumnDefault.String { - fmt.Println(columnCode.DataType.String) - fmt.Println(columnDb.DataType.String) - fmt.Println("-------") - fmt.Println(columnCode.Extra.String) - fmt.Println(columnDb.Extra.String) - fmt.Println("-------") - - //sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) - //_, err := mm.Ex.Exec(sql) - //if err != nil { - // fmt.Println(err) - //} else { - // fmt.Println("修改属性:" + sql) - //} + sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) + _, err := mm.Ex.Exec(sql) + if err != nil { + fmt.Println(sql) + fmt.Println(err) + } else { + fmt.Println("修改属性:" + sql) + } } } } @@ -268,6 +273,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) _, err := mm.Ex.Exec(sql) if err != nil { + fmt.Println(sql) fmt.Println(err) } else { fmt.Println("增加属性:" + sql) @@ -296,13 +302,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co } if isFind == 0 { - sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode) - _, err := mm.Ex.Exec(sql) - if err != nil { - fmt.Println(err) - } else { - fmt.Println("增加索引:" + sql) - } + mm.createIndex(tableFromCode.TableName.String, indexCode) } } } @@ -335,17 +335,26 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co for i := 0; i < len(indexesFromCode); i++ { index := indexesFromCode[i] if index.KeyName.String != "PRIMARY" { - keyType := "" - if index.NonUnique.Int64 == 0 { - keyType = "UNIQUE" - } - - sql := "CREATE " + keyType + " INDEX " + index.KeyName.String + " on " + tableFromCode.TableName.String + " (" + index.ColumnName.String + ")" - mm.Ex.Exec(sql) + mm.createIndex(tableFromCode.TableName.String, index) } } } +func (mm *MigrateExecutor) createIndex(tableName string, index Index) { + keyType := "" + if index.NonUnique.Int64 == 0 { + keyType = "UNIQUE" + } + + sql := "CREATE " + keyType + " INDEX " + index.KeyName.String + " on " + tableName + " (" + index.ColumnName.String + ")" + _, err := mm.Ex.Exec(sql) + if err != nil { + fmt.Println(err) + } else { + fmt.Println("增加索引:" + sql) + } +} + func (mm *MigrateExecutor) getOpinionVal(key string, def string) string { opinions := mm.OpinionList for i := 0; i < len(opinions); i++ { @@ -490,15 +499,6 @@ func getNullAble(fieldMap map[string]string) string { return IsNullable } -func getComment(fieldMap map[string]string) string { - commentVal, commentIs := fieldMap["comment"] - if commentIs { - return commentVal - } - - return "" -} - func getExtra(fieldMap map[string]string) string { _, commentIs := fieldMap["auto_increment"] if commentIs { diff --git a/migrator/migrator.go b/migrator/migrator.go index f62b16a..87f1b7a 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -1,7 +1,7 @@ package migrator import ( - "github.com/tangpanqing/aorm/executor" + "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" "github.com/tangpanqing/aorm/migrate_mysql" "github.com/tangpanqing/aorm/migrate_sqlite3" @@ -42,7 +42,7 @@ func (mi *Migrator) ShowCreateTable(tableName string) string { me := migrate_mysql.MigrateExecutor{ DriverName: mi.driverName, OpinionList: mi.opinionList, - Ex: &executor.Executor{ + Ex: &builder.Builder{ LinkCommon: mi.LinkCommon, }, } @@ -71,7 +71,7 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { me := migrate_mysql.MigrateExecutor{ DriverName: mi.driverName, OpinionList: mi.opinionList, - Ex: &executor.Executor{ + Ex: &builder.Builder{ LinkCommon: mi.LinkCommon, }, } @@ -82,7 +82,7 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { me := migrate_sqlite3.MigrateExecutor{ DriverName: mi.driverName, OpinionList: mi.opinionList, - Ex: &executor.Executor{ + Ex: &builder.Builder{ LinkCommon: mi.LinkCommon, }, } diff --git a/test/aorm_test.go b/test/aorm_test.go index 5c00e1c..ebc8dcd 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -5,7 +5,7 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" "github.com/tangpanqing/aorm" - "github.com/tangpanqing/aorm/executor" + "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" "github.com/tangpanqing/aorm/null" "testing" @@ -56,34 +56,15 @@ type PersonWithArticleCount struct { } func TestAll(t *testing.T) { - - sqlite3Content, sqlite3Err := aorm.Open("sqlite3", "test.db") - if sqlite3Err != nil { - panic(sqlite3Err) - } - - username := "root" - password := "root" - hostname := "localhost" - port := "3306" - dbname := "database_name" - - mysqlContent, mysqlErr := aorm.Open("mysql", username+":"+password+"@tcp("+hostname+":"+port+")/"+dbname+"?charset=utf8mb4&parseTime=True&loc=Local") - if mysqlErr != nil { - panic(mysqlErr) - } - dbList := make([]aorm.DbContent, 0) - dbList = append(dbList, sqlite3Content) - dbList = append(dbList, mysqlContent) + dbList = append(dbList, testSqlite3Connect()) + dbList = append(dbList, testMysqlConnect()) for i := 0; i < len(dbList); i++ { dbItem := dbList[i] testMigrate(dbItem.DriverName, dbItem.DbLink) - return - testShowCreateTable(dbItem.DriverName, dbItem.DbLink) id := testInsert(dbItem.DriverName, dbItem.DbLink) @@ -126,44 +107,38 @@ func TestAll(t *testing.T) { testTruncate(dbItem.DriverName, dbItem.DbLink) testHelper(dbItem.DriverName, dbItem.DbLink) } - - // - //for _, db := range dbMap { - // db.Close() - //} } -func testMysqlConnect() *sql.DB { - //replace this database param +func testSqlite3Connect() aorm.DbContent { + sqlite3Content, sqlite3Err := aorm.Open("sqlite3", "test.db") + if sqlite3Err != nil { + panic(sqlite3Err) + } + return sqlite3Content +} + +func testMysqlConnect() aorm.DbContent { username := "root" password := "root" hostname := "localhost" port := "3306" dbname := "database_name" - //connect - db, err := sql.Open("mysql", username+":"+password+"@tcp("+hostname+":"+port+")/"+dbname+"?charset=utf8mb4&parseTime=True&loc=Local") - if err != nil { - panic(err) - } - //defer db.Close() - - //ping test - err1 := db.Ping() - if err1 != nil { - panic(err1) + mysqlContent, mysqlErr := aorm.Open("mysql", username+":"+password+"@tcp("+hostname+":"+port+")/"+dbname+"?charset=utf8mb4&parseTime=True&loc=Local") + if mysqlErr != nil { + panic(mysqlErr) } - return db + return mysqlContent } func testMigrate(name string, db *sql.DB) { //AutoMigrate aorm.Migrator(db).Driver(name).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").AutoMigrate(&Person{}) - //aorm.Migrator(db).Driver(name).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "文章").AutoMigrate(&Article{}) + aorm.Migrator(db).Driver(name).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "文章").AutoMigrate(&Article{}) //Migrate - //aorm.Migrator(db).Driver(name).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").Migrate("person_1", &Person{}) + aorm.Migrator(db).Driver(name).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").Migrate("person_1", &Person{}) } func testShowCreateTable(name string, db *sql.DB) { @@ -257,7 +232,7 @@ func testDelete(name string, db *sql.DB, id int64) { func testTable(name string, db *sql.DB) { _, err := aorm.Use(db).Debug(false).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) if err != nil { - panic(name + "testTable" + "found err") + panic(name + " testTable " + "found err:" + err.Error()) } } @@ -302,14 +277,14 @@ func testWhereWithSub(name string, db *sql.DB) { func testWhere(name string, db *sql.DB) { var listByWhere []Person - var where1 []executor.WhereItem - where1 = append(where1, executor.WhereItem{Field: "type", Opt: executor.Eq, Val: 0}) - where1 = append(where1, executor.WhereItem{Field: "age", Opt: executor.In, Val: []int{18, 20}}) - where1 = append(where1, executor.WhereItem{Field: "money", Opt: executor.Between, Val: []float64{100.1, 200.9}}) - where1 = append(where1, executor.WhereItem{Field: "money", Opt: executor.Eq, Val: 100.15}) - where1 = append(where1, executor.WhereItem{Field: "name", Opt: executor.Like, Val: []string{"%", "li", "%"}}) + var where1 []builder.WhereItem + where1 = append(where1, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) + where1 = append(where1, builder.WhereItem{Field: "age", Opt: builder.In, Val: []int{18, 20}}) + where1 = append(where1, builder.WhereItem{Field: "money", Opt: builder.Between, Val: []float64{100.1, 200.9}}) + where1 = append(where1, builder.WhereItem{Field: "money", Opt: builder.Eq, Val: 100.15987654321}) + where1 = append(where1, builder.WhereItem{Field: "name", Opt: builder.Like, Val: []string{"%", "li", "%"}}) - err := aorm.Use(db).Debug(false).Table("person").WhereArr(where1).GetMany(&listByWhere) + err := aorm.Use(db).Debug(false).Driver(name).Table("person").WhereArr(where1).GetMany(&listByWhere) if err != nil { panic(name + "testWhere" + "found err") } @@ -317,9 +292,9 @@ func testWhere(name string, db *sql.DB) { func testJoin(name string, db *sql.DB) { var list2 []ArticleVO - var where2 []executor.WhereItem - where2 = append(where2, executor.WhereItem{Field: "o.type", Opt: executor.Eq, Val: 0}) - where2 = append(where2, executor.WhereItem{Field: "p.age", Opt: executor.In, Val: []int{18, 20}}) + var where2 []builder.WhereItem + where2 = append(where2, builder.WhereItem{Field: "o.type", Opt: builder.Eq, Val: 0}) + where2 = append(where2, builder.WhereItem{Field: "p.age", Opt: builder.In, Val: []int{18, 20}}) err := aorm.Use(db).Debug(false). Table("article o"). LeftJoin("person p", "p.id=o.person_id"). @@ -334,8 +309,8 @@ func testJoin(name string, db *sql.DB) { func testGroupBy(name string, db *sql.DB) { var personAge PersonAge - var where []executor.WhereItem - where = append(where, executor.WhereItem{Field: "type", Opt: executor.Eq, Val: 0}) + var where []builder.WhereItem + where = append(where, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) err := aorm.Use(db).Debug(false). Table("person"). Select("age"). @@ -351,11 +326,11 @@ func testGroupBy(name string, db *sql.DB) { func testHaving(name string, db *sql.DB) { var listByHaving []PersonAge - var where3 []executor.WhereItem - where3 = append(where3, executor.WhereItem{Field: "type", Opt: executor.Eq, Val: 0}) + var where3 []builder.WhereItem + where3 = append(where3, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) - var having []executor.WhereItem - having = append(having, executor.WhereItem{Field: "age_count", Opt: executor.Gt, Val: 4}) + var having []builder.WhereItem + having = append(having, builder.WhereItem{Field: "age_count", Opt: builder.Gt, Val: 4}) err := aorm.Use(db).Debug(false). Table("person"). @@ -372,12 +347,12 @@ func testHaving(name string, db *sql.DB) { func testOrderBy(name string, db *sql.DB) { var listByOrder []Person - var where []executor.WhereItem - where = append(where, executor.WhereItem{Field: "type", Opt: executor.Eq, Val: 0}) + var where []builder.WhereItem + where = append(where, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) err := aorm.Use(db).Debug(false). Table("person"). WhereArr(where). - OrderBy("age", executor.Desc). + OrderBy("age", builder.Desc). GetMany(&listByOrder) if err != nil { panic(name + "testOrderBy" + "found err") @@ -386,8 +361,8 @@ func testOrderBy(name string, db *sql.DB) { func testLimit(name string, db *sql.DB) { var list3 []Person - var where1 []executor.WhereItem - where1 = append(where1, executor.WhereItem{Field: "type", Opt: executor.Eq, Val: 0}) + var where1 []builder.WhereItem + where1 = append(where1, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) err1 := aorm.Use(db).Debug(false). Table("person"). WhereArr(where1). @@ -398,8 +373,8 @@ func testLimit(name string, db *sql.DB) { } var list4 []Person - var where2 []executor.WhereItem - where2 = append(where2, executor.WhereItem{Field: "type", Opt: executor.Eq, Val: 0}) + var where2 []builder.WhereItem + where2 = append(where2, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) err := aorm.Use(db).Debug(false). Table("person"). WhereArr(where2). @@ -411,6 +386,9 @@ func testLimit(name string, db *sql.DB) { } func testLock(name string, db *sql.DB, id int64) { + if name == "sqlite3" { + return + } var itemByLock Person err := aorm.Use(db).Debug(false).LockForUpdate(true).Where(&Person{Id: null.IntFrom(id)}).GetOne(&itemByLock) @@ -530,7 +508,6 @@ func testExec(name string, db *sql.DB) { } func testTransaction(name string, db *sql.DB) { - tx, _ := db.Begin() id, errInsert := aorm.Use(tx).Debug(false).Insert(&Person{ @@ -578,18 +555,21 @@ func testTransaction(name string, db *sql.DB) { } func testTruncate(name string, db *sql.DB) { + if name == "sqlite3" { + return + } + _, err := aorm.Use(db).Debug(false).Table("person").Truncate() if err != nil { - panic(name + "testTruncate" + "found err") + panic(name + " testTruncate " + "found err") } } func testHelper(name string, db *sql.DB) { - var list2 []ArticleVO - var where2 []executor.WhereItem - where2 = append(where2, executor.WhereItem{Field: "o.type", Opt: executor.Eq, Val: 0}) - where2 = append(where2, executor.WhereItem{Field: "p.age", Opt: executor.In, Val: []int{18, 20}}) + var where2 []builder.WhereItem + where2 = append(where2, builder.WhereItem{Field: "o.type", Opt: builder.Eq, Val: 0}) + where2 = append(where2, builder.WhereItem{Field: "p.age", Opt: builder.In, Val: []int{18, 20}}) err := aorm.Use(db).Debug(false). Table("article o"). LeftJoin("person p", helper.Ul("p.id=o.personId")).