support table with expression

This commit is contained in:
tangpanqing
2023-03-16 13:29:23 +08:00
parent f4336860df
commit d3e297da13
6 changed files with 140 additions and 36 deletions

View File

@@ -99,10 +99,6 @@ func getPrefixByField(valueOf reflect.Value, prefix ...string) string {
//getTableNameByTable 根据传入的表信息,获取表名
func getTableNameByTable(table interface{}) string {
if table == nil {
panic("当前table不能是nil")
}
valueOf := reflect.ValueOf(table)
if reflect.Ptr == valueOf.Kind() {
return getTableMap(valueOf.Pointer())

View File

@@ -9,7 +9,6 @@ import (
"reflect"
"strconv"
"strings"
"unsafe"
)
const Desc = "DESC"
@@ -284,7 +283,10 @@ func (b *Builder) Update(dest interface{}) (int64, error) {
var args []any
setStr, args := b.handleSet(typeOf, valueOf, args)
whereStr, args := b.handleWhere(args, false)
whereStr, args, err := b.handleWhere(args, false)
if err != nil {
return 0, err
}
query := "UPDATE " + b.getTableNameCommon(typeOf, valueOf) + setStr + whereStr
return b.execAffected(query, args...)
@@ -303,11 +305,17 @@ func (b *Builder) Delete(destList ...interface{}) (int64, error) {
}
if tableName == "" {
if b.table == nil {
return 0, errors.New("表名不能为空")
}
tableName = getTableNameByTable(b.table)
}
var args []any
whereStr, args := b.handleWhere(args, false)
whereStr, args, err := b.handleWhere(args, false)
if err != nil {
return 0, err
}
query := "DELETE FROM " + tableName + whereStr
return b.execAffected(query, args...)
@@ -348,6 +356,10 @@ func (b *Builder) LockForUpdate(isLockForUpdate bool) *Builder {
// Truncate 清空记录
func (b *Builder) Truncate() (int64, error) {
if b.table == nil {
return 0, errors.New("表名不能为空")
}
query := ""
if b.Link.DriverName() == driver.Sqlite3 {
query = "DELETE FROM " + getTableNameByTable(b.table)
@@ -367,7 +379,10 @@ func (b *Builder) RawSql(query string, args ...interface{}) *Builder {
// GetRows 获取行操作
func (b *Builder) GetRows() (*sql.Rows, error) {
query, args := b.GetSqlAndParams()
query, args, err := b.GetSqlAndParams()
if err != nil {
return nil, err
}
if b.Link.DriverName() == driver.Postgres {
query = convertToPostgresSql(query)
@@ -419,7 +434,7 @@ func (b *Builder) Exec() (sql.Result, error) {
}
//拼接SQL,查询与筛选通用操作
func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving bool, needPrefix bool) ([]string, []any) {
func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving bool, needPrefix bool) ([]string, []any, error) {
var whereList []string
for i := 0; i < len(where); i++ {
valueOfField := reflect.ValueOf(where[i].Field)
@@ -445,8 +460,11 @@ func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving boo
}
if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() {
subBuilder := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer()))
subSql, subParams := subBuilder.GetSqlAndParams()
subBuilder := *(**Builder)(reflect.ValueOf(where[i].Val).UnsafePointer())
subSql, subParams, err := subBuilder.GetSqlAndParams()
if err != nil {
return whereList, args, err
}
if where[i].Opt != Raw {
whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"("+subSql+")")
@@ -518,7 +536,7 @@ func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving boo
}
}
}
return whereList, args
return whereList, args, nil
}
func (b *Builder) getConcatForFloat(vars ...string) string {
@@ -547,25 +565,40 @@ func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value)
return getTableNameByReflect(typeOf, valueOf)
}
func (b *Builder) GetSqlAndParams() (string, []interface{}) {
func (b *Builder) GetSqlAndParams() (string, []interface{}, error) {
if b.query != "" {
return b.query, b.args
return b.query, b.args, nil
}
var args []interface{}
fieldStr, args := b.handleSelect(args)
tableName := getTableNameByTable(b.table)
selectStr, args, err := b.handleSelect(args)
if err != nil {
return "", args, err
}
tableStr, args, err := b.handleTable(args)
if err != nil {
return "", args, err
}
joinStr, args := b.handleJoin(args)
whereStr, args := b.handleWhere(args, true)
whereStr, args, err := b.handleWhere(args, true)
if err != nil {
return "", args, err
}
groupStr, args := b.handleGroup(args)
havingStr, args := b.handleHaving(args)
havingStr, args, err := b.handleHaving(args)
if err != nil {
return "", args, err
}
orderStr, args := b.handleOrder(args)
limitStr, args := b.handleLimit(args)
lockStr := b.handleLockForUpdate()
query := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
query := selectStr + tableStr + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
return query, args
return query, args, nil
}
// execAffected 通用执行-更新,删除

View File

@@ -1,6 +1,8 @@
package builder
import (
"errors"
"fmt"
"github.com/tangpanqing/aorm/driver"
"reflect"
"strings"
@@ -29,7 +31,7 @@ func handleSelectWith(selectItem SelectItem) string {
}
//拼接SQL,字段相关
func (b *Builder) handleSelect(paramList []any) (string, []any) {
func (b *Builder) handleSelect(paramList []any) (string, []any, error) {
fieldStr := ""
if b.distinct {
fieldStr += "DISTINCT "
@@ -37,7 +39,7 @@ func (b *Builder) handleSelect(paramList []any) (string, []any) {
if len(b.selectList) == 0 && len(b.selectExpList) == 0 {
fieldStr += "*"
return fieldStr, paramList
return "SELECT " + fieldStr, paramList, nil
}
var strList []string
@@ -59,24 +61,63 @@ func (b *Builder) handleSelect(paramList []any) (string, []any) {
//处理子语句
for i := 0; i < len(b.selectExpList); i++ {
subBuilder := *(b.selectExpList[i].Builder)
subSql, subParamList := subBuilder.GetSqlAndParams()
subSql, subParamList, err := subBuilder.GetSqlAndParams()
if err != nil {
return "", paramList, err
}
strList = append(strList, "("+subSql+") AS "+getFieldNameByField(b.selectExpList[i].FieldName))
paramList = append(paramList, subParamList...)
}
fieldStr += strings.Join(strList, ",")
return fieldStr, paramList
return "SELECT " + fieldStr, paramList, nil
}
func (b *Builder) handleTable(paramList []any) (string, []any, error) {
if b.table == nil {
return "", paramList, errors.New("表不能为空")
}
var tableName string
valueOf := reflect.ValueOf(b.table)
if reflect.Ptr == valueOf.Kind() {
if "**builder.Builder" != valueOf.Type().String() {
tableName = getTableMap(valueOf.Pointer())
} else {
if b.tableAlias == "" {
return "", paramList, errors.New("别名不能为空")
}
subBuilder := *(**Builder)(valueOf.UnsafePointer())
subSql, subParamList, err := subBuilder.GetSqlAndParams()
if err != nil {
return "", paramList, err
}
tableName = "(" + subSql + ")"
paramList = append(paramList, subParamList...)
}
} else {
tableName = fmt.Sprintf("%v", b.table)
}
return " FROM " + tableName + " " + b.tableAlias, paramList, nil
}
//拼接SQL,查询条件
func (b *Builder) handleWhere(paramList []any, needPrefix bool) (string, []any) {
func (b *Builder) handleWhere(paramList []any, needPrefix bool) (string, []any, error) {
if len(b.whereList) == 0 {
return "", paramList
return "", paramList, nil
}
strList, paramList := b.whereAndHaving(b.whereList, paramList, false, needPrefix)
strList, paramList, err := b.whereAndHaving(b.whereList, paramList, false, needPrefix)
if err != nil {
return "", paramList, nil
}
return " WHERE " + strings.Join(strList, " AND "), paramList
return " WHERE " + strings.Join(strList, " AND "), paramList, nil
}
//拼接SQL,更新信息
@@ -149,14 +190,17 @@ func (b *Builder) handleGroup(paramList []any) (string, []any) {
}
//拼接SQL,结果筛选
func (b *Builder) handleHaving(paramList []any) (string, []any) {
func (b *Builder) handleHaving(paramList []any) (string, []any, error) {
if len(b.havingList) == 0 {
return "", paramList
return "", paramList, nil
}
strList, paramList := b.whereAndHaving(b.havingList, paramList, true, true)
strList, paramList, err := b.whereAndHaving(b.havingList, paramList, true, true)
if err != nil {
return "", paramList, err
}
return " Having " + strings.Join(strList, " AND "), paramList
return " Having " + strings.Join(strList, " AND "), paramList, nil
}
//拼接SQL,结果排序

View File

@@ -1,12 +1,20 @@
package builder
import "errors"
// Increment 某字段自增
func (b *Builder) Increment(field interface{}, step int) (int64, error) {
var vars []any
vars = append(vars, step)
whereStr, vars := b.handleWhere(vars, false)
query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "+?" + whereStr
whereStr, vars, err := b.handleWhere(vars, false)
if err != nil {
return 0, err
}
if b.table == nil {
return 0, errors.New("表名不能为空")
}
query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "+?" + whereStr
return b.execAffected(query, vars...)
}
@@ -14,8 +22,14 @@ func (b *Builder) Increment(field interface{}, step int) (int64, error) {
func (b *Builder) Decrement(field interface{}, step int) (int64, error) {
var vars []any
vars = append(vars, step)
whereStr, vars := b.handleWhere(vars, false)
query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "-?" + whereStr
whereStr, vars, err := b.handleWhere(vars, false)
if err != nil {
return 0, err
}
if b.table == nil {
return 0, errors.New("表名不能为空")
}
query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "-?" + whereStr
return b.execAffected(query, vars...)
}

View File

@@ -38,6 +38,16 @@ func (b *Builder) SelectAvg(field interface{}, fieldNew interface{}, prefix ...s
return b.selectCommon("Avg", field, fieldNew, prefix...)
}
// SelectConcat 链式操作-concat(field) as field_new
func (b *Builder) SelectConcat(field interface{}, fieldNew interface{}, prefix ...string) *Builder {
return b.selectCommon("concat", field, fieldNew, prefix...)
}
// SelectGroupConcat 链式操作-group_concat(field) as field_new
func (b *Builder) SelectGroupConcat(field interface{}, fieldNew interface{}, prefix ...string) *Builder {
return b.selectCommon("group_concat", field, fieldNew, prefix...)
}
func (b *Builder) selectCommon(funcName string, field interface{}, fieldNew interface{}, prefix ...string) *Builder {
b.selectList = append(b.selectList, SelectItem{funcName, prefix, field, fieldNew})
return b

View File

@@ -386,6 +386,13 @@ func testTable(db *base.Db) {
if err2 != nil {
panic(db.DriverName() + " testTable " + "found err:" + err2.Error())
}
var personList2 []Person
subTable := aorm.Db(db).Table(&person)
err3 := aorm.Db(db).Table(&subTable, "o").Debug(false).GetMany(&personList2)
if err3 != nil {
panic(db.DriverName() + " testTable " + "found err:" + err3.Error())
}
}
func testSelect(db *base.Db) {