diff --git a/README.md b/README.md index be631d6..af8db74 100644 Binary files a/README.md and b/README.md differ diff --git a/README_zh.md b/README_zh.md index 3770d48..1514d71 100644 Binary files a/README_zh.md and b/README_zh.md differ diff --git a/builder/aggregation.go b/builder/aggregation.go index bdd5c82..db04ab4 100644 --- a/builder/aggregation.go +++ b/builder/aggregation.go @@ -11,7 +11,7 @@ type FloatStruct struct { } // Count 聚合函数-数量 -func (b *Builder) Count(fieldName string) (int64, error) { +func (b *Builder) Count(fieldName interface{}) (int64, error) { var obj []IntStruct err := b.SelectCount(fieldName, "c", "").GetMany(&obj) if err != nil { diff --git a/builder/builder.go b/builder/builder.go index f8b0eac..f8b97bf 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -37,6 +37,11 @@ type OrderItem struct { OrderType string } +type LimitItem struct { + offset int + pageSize int +} + type JoinItem struct { joinType string table interface{} diff --git a/builder/crud.go b/builder/crud.go index d354151..9fd1a95 100644 --- a/builder/crud.go +++ b/builder/crud.go @@ -33,66 +33,55 @@ const RawEq = "RawEq" // Builder 查询记录所需要的条件 type Builder struct { - //数据库操作连接 LinkCommon model.LinkCommon + driverName string table interface{} tableAlias string - //查询参数 - tableName string - selectList []SelectItem - selectExpList []*SelectExpItem - groupList []GroupItem - whereList []WhereItem - joinList []JoinItem - havingList []WhereItem - orderList []OrderItem - offset int - pageSize int + selectList []SelectItem + selectExpList []*SelectExpItem + groupList []GroupItem + whereList []WhereItem + joinList []JoinItem + havingList []WhereItem + orderList []OrderItem + limitItem LimitItem + distinct bool isDebug bool isLockForUpdate bool //sql与参数 - sql string - paramList []interface{} - - //驱动名字 - driverName string + sql string + args []interface{} } -func (b *Builder) Distinct(distinct bool) *Builder { - b.distinct = distinct +// Debug 链式操作-是否开启调试,打印sql +func (b *Builder) Debug(isDebug bool) *Builder { + b.isDebug = isDebug return b } +// Driver 驱动类型 func (b *Builder) Driver(driverName string) *Builder { b.driverName = driverName return b } -func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value) string { - if b.table != nil { - return getTableNameByTable(b.table) - } - - return getTableNameByReflect(typeOf, valueOf) +// Distinct 过滤重复记录 +func (b *Builder) Distinct(distinct bool) *Builder { + b.distinct = distinct + return b } -func getTagMap(fieldTag string) map[string]string { - var fieldMap = make(map[string]string) - if "" != fieldTag { - tagArr := strings.Split(fieldTag, ";") - for j := 0; j < len(tagArr); j++ { - tagArrArr := strings.Split(tagArr[j], ":") - fieldMap[tagArrArr[0]] = "" - if len(tagArrArr) > 1 { - fieldMap[tagArrArr[0]] = tagArrArr[1] - } - } +// Table 链式操作-从哪个表查询,允许直接写别名,例如 person p +func (b *Builder) Table(table interface{}, alias ...string) *Builder { + b.table = table + if len(alias) > 0 { + b.tableAlias = alias[0] } - return fieldMap + return b } // Insert 增加记录 @@ -104,7 +93,7 @@ func (b *Builder) Insert(dest interface{}) (int64, error) { var primaryKey = "" var keys []string - var paramList []any + var args []any var place []string for i := 0; i < typeOf.Elem().NumField(); i++ { key, tagMap := getFieldNameByReflect(typeOf.Elem().Field(i)) @@ -121,7 +110,7 @@ func (b *Builder) Insert(dest interface{}) (int64, error) { val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() keys = append(keys, key) - paramList = append(paramList, val) + args = append(args, val) place = append(place, "?") } } @@ -129,23 +118,23 @@ func (b *Builder) Insert(dest interface{}) (int64, error) { sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf) + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" if b.driverName == model.Mssql { - return b.insertForMssqlOrPostgres(sqlStr+"; SELECT SCOPE_IDENTITY()", paramList...) + return b.insertForMssqlOrPostgres(sqlStr+"; SELECT SCOPE_IDENTITY()", args...) } else if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) - return b.insertForMssqlOrPostgres(sqlStr+" RETURNING "+primaryKey, paramList...) + return b.insertForMssqlOrPostgres(sqlStr+" RETURNING "+primaryKey, args...) } else { - return b.insertForCommon(sqlStr, paramList...) + return b.insertForCommon(sqlStr, args...) } } //对于Mssql,Postgres类型数据库,为了获取最后插入的id,需要改写入为查询 -func (b *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, error) { +func (b *Builder) insertForMssqlOrPostgres(sql string, args ...any) (int64, error) { if b.isDebug { fmt.Println(sql) - fmt.Println(paramList...) + fmt.Println(args...) } - rows, err := b.LinkCommon.Query(sql, paramList...) + rows, err := b.LinkCommon.Query(sql, args...) if err != nil { return 0, err } @@ -158,8 +147,8 @@ func (b *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, } //对于非Mssql,Postgres类型数据库,可以直接获取最后插入的id -func (b *Builder) insertForCommon(sql string, paramList ...any) (int64, error) { - res, err := b.Exec(sql, paramList...) +func (b *Builder) insertForCommon(sql string, args ...any) (int64, error) { + res, err := b.RawSql(sql, args...).Exec() if err != nil { return 0, err } @@ -172,24 +161,10 @@ func (b *Builder) insertForCommon(sql string, paramList ...any) (int64, error) { return lastId, nil } -//对于Postgres数据库,不支持?占位符,支持$1,$2类型,需要做转换 -func convertToPostgresSql(sqlStr string) string { - t := 1 - for { - if strings.Index(sqlStr, "?") == -1 { - break - } - sqlStr = strings.Replace(sqlStr, "?", "$"+strconv.Itoa(t), 1) - t += 1 - } - - return sqlStr -} - // InsertBatch 批量增加记录 func (b *Builder) InsertBatch(values interface{}) (int64, error) { var keys []string - var paramList []any + var args []any var place []string valueOf := reflect.ValueOf(values).Elem() @@ -211,7 +186,7 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) { } val := valueOf.Index(j).Elem().Field(i).Field(0).Field(0).Interface() - paramList = append(paramList, val) + args = append(args, val) placeItem = append(placeItem, "?") } } @@ -225,7 +200,7 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) { sqlStr = convertToPostgresSql(sqlStr) } - res, err := b.Exec(sqlStr, paramList...) + res, err := b.RawSql(sqlStr, args...).Exec() if err != nil { return 0, err } @@ -238,24 +213,6 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) { return count, nil } -// GetRows 获取行操作 -func (b *Builder) GetRows() (*sql.Rows, error) { - sqlStr, paramList := b.GetSqlAndParams() - - smt, errSmt := b.LinkCommon.Prepare(sqlStr) - if errSmt != nil { - return nil, errSmt - } - //defer smt.Close() - - rows, errRows := smt.Query(paramList...) - if errRows != nil { - return nil, errRows - } - - return rows, nil -} - // GetMany 查询记录(新) func (b *Builder) GetMany(values interface{}) error { rows, errRows := b.GetRows() @@ -324,53 +281,17 @@ func (b *Builder) GetOne(obj interface{}) error { return nil } -// RawSql 执行原始的sql语句 -func (b *Builder) RawSql(sql string, paramList ...interface{}) *Builder { - b.sql = sql - b.paramList = paramList - return b -} - -func (b *Builder) GetSqlAndParams() (string, []interface{}) { - if b.sql != "" { - return b.sql, b.paramList - } - - var paramList []interface{} - tableName := getTableNameByTable(b.table) - fieldStr, paramList := b.handleSelect(paramList) - whereStr, paramList := b.handleWhere(paramList) - joinStr, paramList := b.handleJoin(paramList) - groupStr, paramList := b.handleGroup(paramList) - havingStr, paramList := b.handleHaving(paramList) - orderStr, paramList := b.handleOrder(paramList) - limitStr, paramList := b.handleLimit(paramList) - lockStr := b.handleLockForUpdate() - - sql := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr - if b.driverName == model.Postgres { - sql = convertToPostgresSql(sql) - } - - if b.isDebug { - fmt.Println(sql) - fmt.Println(paramList...) - } - - return sql, paramList -} - // Update 更新记录 func (b *Builder) Update(dest interface{}) (int64, error) { typeOf := reflect.TypeOf(dest) valueOf := reflect.ValueOf(dest) - var paramList []any - setStr, paramList := b.handleSet(typeOf, valueOf, paramList) - whereStr, paramList := b.handleWhere(paramList) + var args []any + setStr, args := b.handleSet(typeOf, valueOf, args) + whereStr, args := b.handleWhere(args) sqlStr := "UPDATE " + b.getTableNameCommon(typeOf, valueOf) + setStr + whereStr - return b.ExecAffected(sqlStr, paramList...) + return b.execAffected(sqlStr, args...) } // Delete 删除记录 @@ -389,208 +310,11 @@ func (b *Builder) Delete(destList ...interface{}) (int64, error) { tableName = getTableNameByTable(b.table) } - var paramList []any - whereStr, paramList := b.handleWhere(paramList) + var args []any + whereStr, args := b.handleWhere(args) sqlStr := "DELETE FROM " + tableName + whereStr - return b.ExecAffected(sqlStr, paramList...) -} - -// Truncate 清空记录 -func (b *Builder) Truncate() (int64, error) { - sqlStr := "" - if b.driverName == model.Sqlite3 { - sqlStr = "DELETE FROM " + getTableNameByTable(b.table) - } else { - sqlStr = "TRUNCATE TABLE " + getTableNameByTable(b.table) - } - - return b.ExecAffected(sqlStr) -} - -// Exists 存在某记录 -func (b *Builder) Exists() (bool, error) { - var obj IntStruct - - err := b.selectCommon("", "1 AS c", nil, "").Limit(0, 1).GetOne(&obj) - if err != nil { - return false, err - } - - if obj.C.Int64 == 1 { - return true, nil - } else { - return false, nil - } -} - -// DoesntExist 不存在某记录 -func (b *Builder) DoesntExist() (bool, error) { - isE, err := b.Exists() - return !isE, err -} - -// Value 字段值 -func (b *Builder) Value(field interface{}, dest interface{}) error { - b.Select(field).Limit(0, 1) - - fieldName := getFieldName(field) - - rows, errRows := b.GetRows() - defer rows.Close() - if errRows != nil { - return errRows - } - - destValue := reflect.ValueOf(dest).Elem() - - //从数据库中读出来的字段名字 - columnNameList, errColumns := rows.Columns() - if errColumns != nil { - return errColumns - } - - for rows.Next() { - var scans []interface{} - for _, columnName := range columnNameList { - if fieldName == columnName { - scans = append(scans, destValue.Addr().Interface()) - } else { - var emptyVal interface{} - scans = append(scans, &emptyVal) - } - } - - err := rows.Scan(scans...) - if err != nil { - return err - } - } - - return nil -} - -// Pluck 获取某一列的值 -func (b *Builder) Pluck(field interface{}, values interface{}) error { - b.Select(field) - fieldName := getFieldName(field) - - rows, errRows := b.GetRows() - defer rows.Close() - if errRows != nil { - return errRows - } - - destSlice := reflect.Indirect(reflect.ValueOf(values)) - destType := destSlice.Type().Elem() - destValue := reflect.New(destType).Elem() - - //从数据库中读出来的字段名字 - columnNameList, errColumns := rows.Columns() - if errColumns != nil { - return errColumns - } - - for rows.Next() { - var scans []interface{} - for _, columnName := range columnNameList { - if fieldName == columnName { - scans = append(scans, destValue.Addr().Interface()) - } else { - var emptyVal interface{} - scans = append(scans, &emptyVal) - } - } - - err := rows.Scan(scans...) - if err != nil { - return err - } - - destSlice.Set(reflect.Append(destSlice, destValue)) - } - - return nil -} - -// Increment 某字段自增 -func (b *Builder) Increment(field interface{}, step int) (int64, error) { - var paramList []any - paramList = append(paramList, step) - whereStr, paramList := b.handleWhere(paramList) - sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "+?" + whereStr - - return b.ExecAffected(sqlStr, paramList...) -} - -// Decrement 某字段自减 -func (b *Builder) Decrement(field interface{}, step int) (int64, error) { - var paramList []any - paramList = append(paramList, step) - whereStr, paramList := b.handleWhere(paramList) - sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "-?" + whereStr - - return b.ExecAffected(sqlStr, paramList...) -} - -// Exec 通用执行-新增,更新,删除 -func (b *Builder) Exec(sql string, paramList ...interface{}) (sql.Result, error) { - if b.driverName == model.Postgres { - sql = convertToPostgresSql(sql) - } - - if b.isDebug { - fmt.Println(sql) - fmt.Println(paramList...) - } - - smt, err1 := b.LinkCommon.Prepare(sql) - if err1 != nil { - return nil, err1 - } - defer smt.Close() - - res, err2 := smt.Exec(paramList...) - if err2 != nil { - return nil, err2 - } - - //b.clear() - return res, nil -} - -// ExecAffected 通用执行-更新,删除 -func (b *Builder) ExecAffected(sql string, paramList ...interface{}) (int64, error) { - if b.driverName == model.Postgres { - sql = convertToPostgresSql(sql) - } - - res, err := b.Exec(sql, paramList...) - if err != nil { - return 0, err - } - - count, err := res.RowsAffected() - if err != nil { - return 0, err - } - - return count, nil -} - -// Debug 链式操作-是否开启调试,打印sql -func (b *Builder) Debug(isDebug bool) *Builder { - b.isDebug = isDebug - return b -} - -// Table 链式操作-从哪个表查询,允许直接写别名,例如 person p -func (b *Builder) Table(table interface{}, alias ...string) *Builder { - b.table = table - if len(alias) > 0 { - b.tableAlias = alias[0] - } - return b + return b.execAffected(sqlStr, args...) } // GroupBy 链式操作,以某字段进行分组 @@ -602,39 +326,104 @@ func (b *Builder) GroupBy(field interface{}, prefix ...string) *Builder { return b } -// OrderBy 链式操作,以某字段进行排序 -func (b *Builder) OrderBy(field interface{}, orderType string, prefix ...string) *Builder { - b.orderList = append(b.orderList, OrderItem{ - Prefix: getPrefixByField(field, prefix...), - Field: field, - OrderType: orderType, - }) - - return b -} - // Limit 链式操作,分页 func (b *Builder) Limit(offset int, pageSize int) *Builder { - b.offset = offset - b.pageSize = pageSize + b.limitItem = LimitItem{ + offset: offset, + pageSize: pageSize, + } return b } // Page 链式操作,分页 func (b *Builder) Page(pageNum int, pageSize int) *Builder { - b.offset = (pageNum - 1) * pageSize - b.pageSize = pageSize + b.limitItem = LimitItem{ + offset: (pageNum - 1) * pageSize, + pageSize: pageSize, + } return b } -// LockForUpdate 加锁, sqlte3不支持此操作 +// LockForUpdate 加锁, sqlite3不支持此操作 func (b *Builder) LockForUpdate(isLockForUpdate bool) *Builder { b.isLockForUpdate = isLockForUpdate return b } +// Truncate 清空记录 +func (b *Builder) Truncate() (int64, error) { + sqlStr := "" + if b.driverName == model.Sqlite3 { + sqlStr = "DELETE FROM " + getTableNameByTable(b.table) + } else { + sqlStr = "TRUNCATE TABLE " + getTableNameByTable(b.table) + } + + return b.execAffected(sqlStr) +} + +// RawSql 执行原始的sql语句 +func (b *Builder) RawSql(sql string, args ...interface{}) *Builder { + b.sql = sql + b.args = args + return b +} + +// GetRows 获取行操作 +func (b *Builder) GetRows() (*sql.Rows, error) { + sql, args := b.GetSqlAndParams() + + if b.driverName == model.Postgres { + sql = convertToPostgresSql(sql) + } + + if b.isDebug { + fmt.Println(sql) + fmt.Println(args...) + } + + smt, errSmt := b.LinkCommon.Prepare(sql) + if errSmt != nil { + return nil, errSmt + } + //defer smt.Close() + + rows, errRows := smt.Query(args...) + if errRows != nil { + return nil, errRows + } + + return rows, nil +} + +// Exec 通用执行-新增,更新,删除 +func (b *Builder) Exec() (sql.Result, error) { + if b.driverName == model.Postgres { + b.sql = convertToPostgresSql(b.sql) + } + + if b.isDebug { + fmt.Println(b.sql) + fmt.Println(b.args...) + } + + smt, err1 := b.LinkCommon.Prepare(b.sql) + if err1 != nil { + return nil, err1 + } + defer smt.Close() + + res, err2 := smt.Exec(b.args...) + if err2 != nil { + return nil, err2 + } + + //b.clear() + return res, nil +} + //拼接SQL,查询与筛选通用操作 -func (b *Builder) whereAndHaving(where []WhereItem, paramList []any, isFromHaving bool) ([]string, []any) { +func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving bool) ([]string, []any) { var whereList []string for i := 0; i < len(where); i++ { allFieldName := "" @@ -655,12 +444,12 @@ func (b *Builder) whereAndHaving(where []WhereItem, paramList []any, isFromHavin } if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() { - executor := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer())) - subSql, subParams := executor.GetSqlAndParams() + subBuilder := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer())) + subSql, subParams := subBuilder.GetSqlAndParams() if where[i].Opt != Raw { whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"("+subSql+")") - paramList = append(paramList, subParams...) + args = append(args, subParams...) } } else { if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le { @@ -677,13 +466,13 @@ func (b *Builder) whereAndHaving(where []WhereItem, paramList []any, isFromHavin } } - paramList = append(paramList, fmt.Sprintf("%v", where[i].Val)) + args = append(args, fmt.Sprintf("%v", where[i].Val)) } if where[i].Opt == Between || where[i].Opt == NotBetween { values := toAnyArr(where[i].Val) whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"(?) AND (?)") - paramList = append(paramList, values...) + args = append(args, values...) } if where[i].Opt == Like || where[i].Opt == NotLike { @@ -693,7 +482,7 @@ func (b *Builder) whereAndHaving(where []WhereItem, paramList []any, isFromHavin str := fmt.Sprintf("%v", values[j]) if "%" != str { - paramList = append(paramList, str) + args = append(args, str) valueStr = append(valueStr, "?") } else { valueStr = append(valueStr, "'"+str+"'") @@ -711,7 +500,7 @@ func (b *Builder) whereAndHaving(where []WhereItem, paramList []any, isFromHavin } whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"("+strings.Join(placeholder, ",")+")") - paramList = append(paramList, values...) + args = append(args, values...) } if where[i].Opt == Raw { @@ -723,7 +512,7 @@ func (b *Builder) whereAndHaving(where []WhereItem, paramList []any, isFromHavin } } } - return whereList, paramList + return whereList, args } func (b *Builder) getConcatForFloat(vars ...string) string { @@ -743,3 +532,80 @@ func (b *Builder) getConcatForLike(vars ...string) string { return "CONCAT(" + strings.Join(vars, ",") + ")" } } + +func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value) string { + if b.table != nil { + return getTableNameByTable(b.table) + } + + return getTableNameByReflect(typeOf, valueOf) +} + +func (b *Builder) GetSqlAndParams() (string, []interface{}) { + if b.sql != "" { + return b.sql, b.args + } + + var args []interface{} + tableName := getTableNameByTable(b.table) + fieldStr, args := b.handleSelect(args) + whereStr, args := b.handleWhere(args) + joinStr, args := b.handleJoin(args) + groupStr, args := b.handleGroup(args) + havingStr, args := b.handleHaving(args) + orderStr, args := b.handleOrder(args) + limitStr, args := b.handleLimit(args) + lockStr := b.handleLockForUpdate() + + sql := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr + + return sql, args +} + +// execAffected 通用执行-更新,删除 +func (b *Builder) execAffected(sql string, args ...interface{}) (int64, error) { + if b.driverName == model.Postgres { + sql = convertToPostgresSql(sql) + } + + res, err := b.RawSql(sql, args...).Exec() + if err != nil { + return 0, err + } + + count, err := res.RowsAffected() + if err != nil { + return 0, err + } + + return count, nil +} + +func getTagMap(fieldTag string) map[string]string { + var fieldMap = make(map[string]string) + if "" != fieldTag { + tagArr := strings.Split(fieldTag, ";") + for j := 0; j < len(tagArr); j++ { + tagArrArr := strings.Split(tagArr[j], ":") + fieldMap[tagArrArr[0]] = "" + if len(tagArrArr) > 1 { + fieldMap[tagArrArr[0]] = tagArrArr[1] + } + } + } + return fieldMap +} + +//对于Postgres数据库,不支持?占位符,支持$1,$2类型,需要做转换 +func convertToPostgresSql(sqlStr string) string { + t := 1 + for { + if strings.Index(sqlStr, "?") == -1 { + break + } + sqlStr = strings.Replace(sqlStr, "?", "$"+strconv.Itoa(t), 1) + t += 1 + } + + return sqlStr +} diff --git a/builder/exists.go b/builder/exists.go new file mode 100644 index 0000000..1b05f6e --- /dev/null +++ b/builder/exists.go @@ -0,0 +1,23 @@ +package builder + +// Exists 存在某记录 +func (b *Builder) Exists() (bool, error) { + var obj IntStruct + + err := b.selectCommon("", "1 AS c", nil, "").Limit(0, 1).GetOne(&obj) + if err != nil { + return false, err + } + + if obj.C.Int64 == 1 { + return true, nil + } else { + return false, nil + } +} + +// DoesntExist 不存在某记录 +func (b *Builder) DoesntExist() (bool, error) { + isE, err := b.Exists() + return !isE, err +} diff --git a/builder/handle.go b/builder/handle.go index 9a5316d..62b9fa1 100644 --- a/builder/handle.go +++ b/builder/handle.go @@ -57,8 +57,8 @@ func (b *Builder) handleSelect(paramList []any) (string, []any) { //处理子语句 for i := 0; i < len(b.selectExpList); i++ { - executor := *(b.selectExpList[i].Builder) - subSql, subParamList := executor.GetSqlAndParams() + subBuilder := *(b.selectExpList[i].Builder) + subSql, subParamList := subBuilder.GetSqlAndParams() strList = append(strList, "("+subSql+") AS "+getFieldName(b.selectExpList[i].FieldName)) paramList = append(paramList, subParamList...) } @@ -82,8 +82,8 @@ func (b *Builder) handleWhere(paramList []any) (string, []any) { func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramList []any) (string, []any) { //如果没有设置表名 - if b.tableName == "" { - b.tableName = getTableNameByReflect(typeOf, valueOf) + if b.table == nil { + b.table = getTableNameByReflect(typeOf, valueOf) } var keys []string @@ -163,19 +163,19 @@ func (b *Builder) handleOrder(paramList []any) (string, []any) { //拼接SQL,分页相关 Postgres数据库分页数量在前偏移在后,其他数据库偏移量在前分页数量在后,另外Mssql数据库的关键词是offset...next func (b *Builder) handleLimit(paramList []any) (string, []any) { - if 0 == b.pageSize { + if 0 == b.limitItem.pageSize { return "", paramList } str := "" if b.driverName == model.Postgres { - paramList = append(paramList, b.pageSize) - paramList = append(paramList, b.offset) + paramList = append(paramList, b.limitItem.pageSize) + paramList = append(paramList, b.limitItem.offset) str = " Limit ? offset ? " } else { - paramList = append(paramList, b.offset) - paramList = append(paramList, b.pageSize) + paramList = append(paramList, b.limitItem.offset) + paramList = append(paramList, b.limitItem.pageSize) str = " Limit ?,? " if b.driverName == model.Mssql { diff --git a/builder/having.go b/builder/having.go index 2d39e27..a633faf 100644 --- a/builder/having.go +++ b/builder/having.go @@ -11,8 +11,8 @@ func (b *Builder) Having(dest interface{}) *Builder { valueOf := reflect.ValueOf(dest) //如果没有设置表名 - if b.tableName == "" { - b.tableName = getTableNameByReflect(typeOf, valueOf) + if b.table == nil { + b.table = getTableNameByReflect(typeOf, valueOf) } for i := 0; i < typeOf.Elem().NumField(); i++ { diff --git a/builder/increment.go b/builder/increment.go new file mode 100644 index 0000000..3723c20 --- /dev/null +++ b/builder/increment.go @@ -0,0 +1,21 @@ +package builder + +// Increment 某字段自增 +func (b *Builder) Increment(field interface{}, step int) (int64, error) { + var paramList []any + paramList = append(paramList, step) + whereStr, paramList := b.handleWhere(paramList) + sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "+?" + whereStr + + return b.execAffected(sqlStr, paramList...) +} + +// Decrement 某字段自减 +func (b *Builder) Decrement(field interface{}, step int) (int64, error) { + var paramList []any + paramList = append(paramList, step) + whereStr, paramList := b.handleWhere(paramList) + sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "-?" + whereStr + + return b.execAffected(sqlStr, paramList...) +} diff --git a/builder/order.go b/builder/order.go new file mode 100644 index 0000000..80302a0 --- /dev/null +++ b/builder/order.go @@ -0,0 +1,20 @@ +package builder + +func (b *Builder) OrderDescBy(field interface{}, prefix ...string) *Builder { + return b.OrderBy(field, Desc, prefix...) +} + +func (b *Builder) OrderAscBy(field interface{}, prefix ...string) *Builder { + return b.OrderBy(field, Asc, prefix...) +} + +// OrderBy 链式操作,以某字段进行排序 +func (b *Builder) OrderBy(field interface{}, orderType string, prefix ...string) *Builder { + b.orderList = append(b.orderList, OrderItem{ + Prefix: getPrefixByField(field, prefix...), + Field: field, + OrderType: orderType, + }) + + return b +} diff --git a/builder/value.go b/builder/value.go new file mode 100644 index 0000000..8a40528 --- /dev/null +++ b/builder/value.go @@ -0,0 +1,86 @@ +package builder + +import "reflect" + +// Value 字段值 +func (b *Builder) Value(field interface{}, dest interface{}) error { + b.Select(field).Limit(0, 1) + + fieldName := getFieldName(field) + + rows, errRows := b.GetRows() + defer rows.Close() + if errRows != nil { + return errRows + } + + destValue := reflect.ValueOf(dest).Elem() + + //从数据库中读出来的字段名字 + columnNameList, errColumns := rows.Columns() + if errColumns != nil { + return errColumns + } + + for rows.Next() { + var scans []interface{} + for _, columnName := range columnNameList { + if fieldName == columnName { + scans = append(scans, destValue.Addr().Interface()) + } else { + var emptyVal interface{} + scans = append(scans, &emptyVal) + } + } + + err := rows.Scan(scans...) + if err != nil { + return err + } + } + + return nil +} + +// Pluck 获取某一列的值 +func (b *Builder) Pluck(field interface{}, values interface{}) error { + b.Select(field) + fieldName := getFieldName(field) + + rows, errRows := b.GetRows() + defer rows.Close() + if errRows != nil { + return errRows + } + + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + destValue := reflect.New(destType).Elem() + + //从数据库中读出来的字段名字 + columnNameList, errColumns := rows.Columns() + if errColumns != nil { + return errColumns + } + + for rows.Next() { + var scans []interface{} + for _, columnName := range columnNameList { + if fieldName == columnName { + scans = append(scans, destValue.Addr().Interface()) + } else { + var emptyVal interface{} + scans = append(scans, &emptyVal) + } + } + + err := rows.Scan(scans...) + if err != nil { + return err + } + + destSlice.Set(reflect.Append(destSlice, destValue)) + } + + return nil +} diff --git a/builder/where.go b/builder/where.go index d84920f..ef0186e 100644 --- a/builder/where.go +++ b/builder/where.go @@ -11,8 +11,8 @@ func (b *Builder) Where(dest interface{}) *Builder { valueOf := reflect.ValueOf(dest) //如果没有设置表名 - if b.tableName == "" { - b.tableName = getTableNameByReflect(typeOf, valueOf) + if b.table == nil { + b.table = getTableNameByReflect(typeOf, valueOf) } for i := 0; i < typeOf.Elem().NumField(); i++ { diff --git a/migrate_mssql/migrate.go b/migrate_mssql/migrate.go index e3fbc10..5f666ba 100644 --- a/migrate_mssql/migrate.go +++ b/migrate_mssql/migrate.go @@ -220,7 +220,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) fmt.Println(sql) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -232,7 +232,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -261,7 +261,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if !keyMatch || indexCode.NonUnique.Int64 != indexDb.NonUnique.Int64 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -273,7 +273,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -297,9 +297,8 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co } sqlStr := "CREATE TABLE " + tableFromCode.TableName.String + " (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" - fmt.Println(sqlStr) - _, err := mm.Ex.Exec(sqlStr) + _, err := mm.Ex.RawSql(sqlStr).Exec() if err != nil { fmt.Println(err) } else { diff --git a/migrate_mysql/migrate.go b/migrate_mysql/migrate.go index 7e794b7..ec59ff9 100644 --- a/migrate_mysql/migrate.go +++ b/migrate_mysql/migrate.go @@ -232,7 +232,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co columnCode.Extra.String != columnDb.Extra.String || columnCode.ColumnDefault.String != columnDb.ColumnDefault.String { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -244,7 +244,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -263,7 +263,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co isFind = 1 if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -275,7 +275,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -287,7 +287,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co func (mm *MigrateExecutor) modifyTableEngine(tableFromCode Table) { sql := "ALTER TABLE " + tableFromCode.TableName.String + " Engine " + tableFromCode.Engine.String - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -297,7 +297,7 @@ func (mm *MigrateExecutor) modifyTableEngine(tableFromCode Table) { func (mm *MigrateExecutor) modifyTableComment(tableFromCode Table) { sql := "ALTER TABLE " + tableFromCode.TableName.String + " Comment " + tableFromCode.TableComment.String - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -318,8 +318,8 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co fieldArr = append(fieldArr, getIndexStr(index)) } - sqlStr := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + " ENGINE " + tableFromCode.Engine.String + " COMMENT " + tableFromCode.TableComment.String + ";" - _, err := mm.Ex.Exec(sqlStr) + sql := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + " ENGINE " + tableFromCode.Engine.String + " COMMENT " + tableFromCode.TableComment.String + ";" + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { diff --git a/migrate_postgres/migrate.go b/migrate_postgres/migrate.go index 147abaf..7c3bf0d 100644 --- a/migrate_postgres/migrate.go +++ b/migrate_postgres/migrate.go @@ -252,7 +252,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co sql := "ALTER TABLE " + tableFromCode.TableName.String + " alter COLUMN " + getColumnStr(columnCode, "type") //fmt.Println(sql) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -264,7 +264,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode, "") - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -283,7 +283,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co isFind = 1 if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -301,7 +301,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co func (mm *MigrateExecutor) modifyTableComment(tableFromCode Table) { sql := "ALTER TABLE " + tableFromCode.TableName.String + " Comment " + tableFromCode.TableComment.String - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -324,10 +324,9 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co } } - sqlStr := "CREATE TABLE " + tableFromCode.TableName.String + " (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" - fmt.Println(sqlStr) + sql := "CREATE TABLE " + tableFromCode.TableName.String + " (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" - _, err := mm.Ex.Exec(sqlStr) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -350,7 +349,7 @@ func (mm *MigrateExecutor) createIndex(tableName string, index Index) { } sql := "CREATE " + keyType + " INDEX " + index.KeyName.String + " on " + tableName + " (" + index.ColumnName.String + ")" - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { diff --git a/migrate_sqlite3/migrate.go b/migrate_sqlite3/migrate.go index 1ddc7e7..5cfbe48 100644 --- a/migrate_sqlite3/migrate.go +++ b/migrate_sqlite3/migrate.go @@ -258,7 +258,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co columnCode.ColumnDefault.String != columnDb.ColumnDefault.String { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(sql) fmt.Println(err) @@ -271,7 +271,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(sql) fmt.Println(err) @@ -291,7 +291,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co isFind = 1 if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -323,8 +323,8 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co } //创建表结构与主键索引 - sqlStr := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" - _, err := mm.Ex.Exec(sqlStr) + sql := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -347,7 +347,7 @@ func (mm *MigrateExecutor) createIndex(tableName string, index Index) { } sql := "CREATE " + keyType + " INDEX " + index.KeyName.String + " on " + tableName + " (" + index.ColumnName.String + ")" - _, err := mm.Ex.Exec(sql) + _, err := mm.Ex.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { diff --git a/test/aorm_test.go b/test/aorm_test.go index dd3d869..d0536fb 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -156,7 +156,7 @@ func TestAll(t *testing.T) { testDistinct(dbItem.DriverName, dbItem.DbLink) - testExec(dbItem.DriverName, dbItem.DbLink) + testRawSql(dbItem.DriverName, dbItem.DbLink, id2) testTransaction(dbItem.DriverName, dbItem.DbLink) testTruncate(dbItem.DriverName, dbItem.DbLink) @@ -556,6 +556,7 @@ func testLock(driver string, db *sql.DB, id int64) { if driver == model.Sqlite3 || driver == model.Mssql { return } + var itemByLock Person err := aorm.Db(db). Debug(false). @@ -680,10 +681,16 @@ func testDistinct(driver string, db *sql.DB) { } } -func testExec(driver string, db *sql.DB) { - _, err := aorm.Db(db).Debug(false).Driver(driver).Exec("UPDATE person SET name = ? WHERE person.id=?", "Bob", 3) +func testRawSql(driver string, db *sql.DB, id2 int64) { + var list []Person + err1 := aorm.Db(db).Debug(false).Driver(driver).RawSql("SELECT * FROM person WHERE id=? AND type=?", id2, 0).GetMany(&list) + if err1 != nil { + panic(err1) + } + + _, err := aorm.Db(db).Debug(false).Driver(driver).RawSql("UPDATE person SET name = ? WHERE id=?", "Bob2", id2).Exec() if err != nil { - panic(driver + "testExec" + "found err") + panic(driver + "testRawSql" + "found err") } }