mirror of
https://github.com/tangpanqing/aorm.git
synced 2025-12-24 12:13:03 +08:00
support table with expression
This commit is contained in:
@@ -99,10 +99,6 @@ func getPrefixByField(valueOf reflect.Value, prefix ...string) string {
|
||||
|
||||
//getTableNameByTable 根据传入的表信息,获取表名
|
||||
func getTableNameByTable(table interface{}) string {
|
||||
if table == nil {
|
||||
panic("当前table不能是nil")
|
||||
}
|
||||
|
||||
valueOf := reflect.ValueOf(table)
|
||||
if reflect.Ptr == valueOf.Kind() {
|
||||
return getTableMap(valueOf.Pointer())
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const Desc = "DESC"
|
||||
@@ -284,7 +283,10 @@ func (b *Builder) Update(dest interface{}) (int64, error) {
|
||||
|
||||
var args []any
|
||||
setStr, args := b.handleSet(typeOf, valueOf, args)
|
||||
whereStr, args := b.handleWhere(args, false)
|
||||
whereStr, args, err := b.handleWhere(args, false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
query := "UPDATE " + b.getTableNameCommon(typeOf, valueOf) + setStr + whereStr
|
||||
|
||||
return b.execAffected(query, args...)
|
||||
@@ -303,11 +305,17 @@ func (b *Builder) Delete(destList ...interface{}) (int64, error) {
|
||||
}
|
||||
|
||||
if tableName == "" {
|
||||
if b.table == nil {
|
||||
return 0, errors.New("表名不能为空")
|
||||
}
|
||||
tableName = getTableNameByTable(b.table)
|
||||
}
|
||||
|
||||
var args []any
|
||||
whereStr, args := b.handleWhere(args, false)
|
||||
whereStr, args, err := b.handleWhere(args, false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
query := "DELETE FROM " + tableName + whereStr
|
||||
|
||||
return b.execAffected(query, args...)
|
||||
@@ -348,6 +356,10 @@ func (b *Builder) LockForUpdate(isLockForUpdate bool) *Builder {
|
||||
|
||||
// Truncate 清空记录
|
||||
func (b *Builder) Truncate() (int64, error) {
|
||||
if b.table == nil {
|
||||
return 0, errors.New("表名不能为空")
|
||||
}
|
||||
|
||||
query := ""
|
||||
if b.Link.DriverName() == driver.Sqlite3 {
|
||||
query = "DELETE FROM " + getTableNameByTable(b.table)
|
||||
@@ -367,7 +379,10 @@ func (b *Builder) RawSql(query string, args ...interface{}) *Builder {
|
||||
|
||||
// GetRows 获取行操作
|
||||
func (b *Builder) GetRows() (*sql.Rows, error) {
|
||||
query, args := b.GetSqlAndParams()
|
||||
query, args, err := b.GetSqlAndParams()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if b.Link.DriverName() == driver.Postgres {
|
||||
query = convertToPostgresSql(query)
|
||||
@@ -419,7 +434,7 @@ func (b *Builder) Exec() (sql.Result, error) {
|
||||
}
|
||||
|
||||
//拼接SQL,查询与筛选通用操作
|
||||
func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving bool, needPrefix bool) ([]string, []any) {
|
||||
func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving bool, needPrefix bool) ([]string, []any, error) {
|
||||
var whereList []string
|
||||
for i := 0; i < len(where); i++ {
|
||||
valueOfField := reflect.ValueOf(where[i].Field)
|
||||
@@ -445,8 +460,11 @@ func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving boo
|
||||
}
|
||||
|
||||
if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() {
|
||||
subBuilder := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer()))
|
||||
subSql, subParams := subBuilder.GetSqlAndParams()
|
||||
subBuilder := *(**Builder)(reflect.ValueOf(where[i].Val).UnsafePointer())
|
||||
subSql, subParams, err := subBuilder.GetSqlAndParams()
|
||||
if err != nil {
|
||||
return whereList, args, err
|
||||
}
|
||||
|
||||
if where[i].Opt != Raw {
|
||||
whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"("+subSql+")")
|
||||
@@ -518,7 +536,7 @@ func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving boo
|
||||
}
|
||||
}
|
||||
}
|
||||
return whereList, args
|
||||
return whereList, args, nil
|
||||
}
|
||||
|
||||
func (b *Builder) getConcatForFloat(vars ...string) string {
|
||||
@@ -547,25 +565,40 @@ func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value)
|
||||
return getTableNameByReflect(typeOf, valueOf)
|
||||
}
|
||||
|
||||
func (b *Builder) GetSqlAndParams() (string, []interface{}) {
|
||||
func (b *Builder) GetSqlAndParams() (string, []interface{}, error) {
|
||||
if b.query != "" {
|
||||
return b.query, b.args
|
||||
return b.query, b.args, nil
|
||||
}
|
||||
|
||||
var args []interface{}
|
||||
fieldStr, args := b.handleSelect(args)
|
||||
tableName := getTableNameByTable(b.table)
|
||||
selectStr, args, err := b.handleSelect(args)
|
||||
if err != nil {
|
||||
return "", args, err
|
||||
}
|
||||
|
||||
tableStr, args, err := b.handleTable(args)
|
||||
if err != nil {
|
||||
return "", args, err
|
||||
}
|
||||
joinStr, args := b.handleJoin(args)
|
||||
whereStr, args := b.handleWhere(args, true)
|
||||
whereStr, args, err := b.handleWhere(args, true)
|
||||
if err != nil {
|
||||
return "", args, err
|
||||
}
|
||||
|
||||
groupStr, args := b.handleGroup(args)
|
||||
havingStr, args := b.handleHaving(args)
|
||||
havingStr, args, err := b.handleHaving(args)
|
||||
if err != nil {
|
||||
return "", args, err
|
||||
}
|
||||
|
||||
orderStr, args := b.handleOrder(args)
|
||||
limitStr, args := b.handleLimit(args)
|
||||
lockStr := b.handleLockForUpdate()
|
||||
|
||||
query := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
|
||||
query := selectStr + tableStr + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
|
||||
|
||||
return query, args
|
||||
return query, args, nil
|
||||
}
|
||||
|
||||
// execAffected 通用执行-更新,删除
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/tangpanqing/aorm/driver"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -29,7 +31,7 @@ func handleSelectWith(selectItem SelectItem) string {
|
||||
}
|
||||
|
||||
//拼接SQL,字段相关
|
||||
func (b *Builder) handleSelect(paramList []any) (string, []any) {
|
||||
func (b *Builder) handleSelect(paramList []any) (string, []any, error) {
|
||||
fieldStr := ""
|
||||
if b.distinct {
|
||||
fieldStr += "DISTINCT "
|
||||
@@ -37,7 +39,7 @@ func (b *Builder) handleSelect(paramList []any) (string, []any) {
|
||||
|
||||
if len(b.selectList) == 0 && len(b.selectExpList) == 0 {
|
||||
fieldStr += "*"
|
||||
return fieldStr, paramList
|
||||
return "SELECT " + fieldStr, paramList, nil
|
||||
}
|
||||
|
||||
var strList []string
|
||||
@@ -59,24 +61,63 @@ func (b *Builder) handleSelect(paramList []any) (string, []any) {
|
||||
//处理子语句
|
||||
for i := 0; i < len(b.selectExpList); i++ {
|
||||
subBuilder := *(b.selectExpList[i].Builder)
|
||||
subSql, subParamList := subBuilder.GetSqlAndParams()
|
||||
subSql, subParamList, err := subBuilder.GetSqlAndParams()
|
||||
if err != nil {
|
||||
return "", paramList, err
|
||||
}
|
||||
strList = append(strList, "("+subSql+") AS "+getFieldNameByField(b.selectExpList[i].FieldName))
|
||||
paramList = append(paramList, subParamList...)
|
||||
}
|
||||
|
||||
fieldStr += strings.Join(strList, ",")
|
||||
return fieldStr, paramList
|
||||
return "SELECT " + fieldStr, paramList, nil
|
||||
}
|
||||
|
||||
func (b *Builder) handleTable(paramList []any) (string, []any, error) {
|
||||
if b.table == nil {
|
||||
return "", paramList, errors.New("表不能为空")
|
||||
}
|
||||
|
||||
var tableName string
|
||||
|
||||
valueOf := reflect.ValueOf(b.table)
|
||||
if reflect.Ptr == valueOf.Kind() {
|
||||
|
||||
if "**builder.Builder" != valueOf.Type().String() {
|
||||
tableName = getTableMap(valueOf.Pointer())
|
||||
} else {
|
||||
if b.tableAlias == "" {
|
||||
return "", paramList, errors.New("别名不能为空")
|
||||
}
|
||||
|
||||
subBuilder := *(**Builder)(valueOf.UnsafePointer())
|
||||
subSql, subParamList, err := subBuilder.GetSqlAndParams()
|
||||
if err != nil {
|
||||
return "", paramList, err
|
||||
}
|
||||
|
||||
tableName = "(" + subSql + ")"
|
||||
paramList = append(paramList, subParamList...)
|
||||
}
|
||||
} else {
|
||||
tableName = fmt.Sprintf("%v", b.table)
|
||||
}
|
||||
|
||||
return " FROM " + tableName + " " + b.tableAlias, paramList, nil
|
||||
}
|
||||
|
||||
//拼接SQL,查询条件
|
||||
func (b *Builder) handleWhere(paramList []any, needPrefix bool) (string, []any) {
|
||||
func (b *Builder) handleWhere(paramList []any, needPrefix bool) (string, []any, error) {
|
||||
if len(b.whereList) == 0 {
|
||||
return "", paramList
|
||||
return "", paramList, nil
|
||||
}
|
||||
|
||||
strList, paramList := b.whereAndHaving(b.whereList, paramList, false, needPrefix)
|
||||
strList, paramList, err := b.whereAndHaving(b.whereList, paramList, false, needPrefix)
|
||||
if err != nil {
|
||||
return "", paramList, nil
|
||||
}
|
||||
|
||||
return " WHERE " + strings.Join(strList, " AND "), paramList
|
||||
return " WHERE " + strings.Join(strList, " AND "), paramList, nil
|
||||
}
|
||||
|
||||
//拼接SQL,更新信息
|
||||
@@ -149,14 +190,17 @@ func (b *Builder) handleGroup(paramList []any) (string, []any) {
|
||||
}
|
||||
|
||||
//拼接SQL,结果筛选
|
||||
func (b *Builder) handleHaving(paramList []any) (string, []any) {
|
||||
func (b *Builder) handleHaving(paramList []any) (string, []any, error) {
|
||||
if len(b.havingList) == 0 {
|
||||
return "", paramList
|
||||
return "", paramList, nil
|
||||
}
|
||||
|
||||
strList, paramList := b.whereAndHaving(b.havingList, paramList, true, true)
|
||||
strList, paramList, err := b.whereAndHaving(b.havingList, paramList, true, true)
|
||||
if err != nil {
|
||||
return "", paramList, err
|
||||
}
|
||||
|
||||
return " Having " + strings.Join(strList, " AND "), paramList
|
||||
return " Having " + strings.Join(strList, " AND "), paramList, nil
|
||||
}
|
||||
|
||||
//拼接SQL,结果排序
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
package builder
|
||||
|
||||
import "errors"
|
||||
|
||||
// Increment 某字段自增
|
||||
func (b *Builder) Increment(field interface{}, step int) (int64, error) {
|
||||
var vars []any
|
||||
vars = append(vars, step)
|
||||
whereStr, vars := b.handleWhere(vars, false)
|
||||
query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "+?" + whereStr
|
||||
whereStr, vars, err := b.handleWhere(vars, false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if b.table == nil {
|
||||
return 0, errors.New("表名不能为空")
|
||||
}
|
||||
query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "+?" + whereStr
|
||||
return b.execAffected(query, vars...)
|
||||
}
|
||||
|
||||
@@ -14,8 +22,14 @@ func (b *Builder) Increment(field interface{}, step int) (int64, error) {
|
||||
func (b *Builder) Decrement(field interface{}, step int) (int64, error) {
|
||||
var vars []any
|
||||
vars = append(vars, step)
|
||||
whereStr, vars := b.handleWhere(vars, false)
|
||||
query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "-?" + whereStr
|
||||
whereStr, vars, err := b.handleWhere(vars, false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if b.table == nil {
|
||||
return 0, errors.New("表名不能为空")
|
||||
}
|
||||
query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "-?" + whereStr
|
||||
return b.execAffected(query, vars...)
|
||||
}
|
||||
|
||||
@@ -38,6 +38,16 @@ func (b *Builder) SelectAvg(field interface{}, fieldNew interface{}, prefix ...s
|
||||
return b.selectCommon("Avg", field, fieldNew, prefix...)
|
||||
}
|
||||
|
||||
// SelectConcat 链式操作-concat(field) as field_new
|
||||
func (b *Builder) SelectConcat(field interface{}, fieldNew interface{}, prefix ...string) *Builder {
|
||||
return b.selectCommon("concat", field, fieldNew, prefix...)
|
||||
}
|
||||
|
||||
// SelectGroupConcat 链式操作-group_concat(field) as field_new
|
||||
func (b *Builder) SelectGroupConcat(field interface{}, fieldNew interface{}, prefix ...string) *Builder {
|
||||
return b.selectCommon("group_concat", field, fieldNew, prefix...)
|
||||
}
|
||||
|
||||
func (b *Builder) selectCommon(funcName string, field interface{}, fieldNew interface{}, prefix ...string) *Builder {
|
||||
b.selectList = append(b.selectList, SelectItem{funcName, prefix, field, fieldNew})
|
||||
return b
|
||||
|
||||
@@ -386,6 +386,13 @@ func testTable(db *base.Db) {
|
||||
if err2 != nil {
|
||||
panic(db.DriverName() + " testTable " + "found err:" + err2.Error())
|
||||
}
|
||||
|
||||
var personList2 []Person
|
||||
subTable := aorm.Db(db).Table(&person)
|
||||
err3 := aorm.Db(db).Table(&subTable, "o").Debug(false).GetMany(&personList2)
|
||||
if err3 != nil {
|
||||
panic(db.DriverName() + " testTable " + "found err:" + err3.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func testSelect(db *base.Db) {
|
||||
|
||||
Reference in New Issue
Block a user