This commit is contained in:
tangpanqing
2023-01-05 17:43:29 +08:00
parent 4e7c71a910
commit 2bc08359b8
8 changed files with 851 additions and 603 deletions

View File

@@ -32,23 +32,26 @@ const NotBetween = "NOT BETWEEN"
const Raw = "Raw"
// SelectItem 将某子语句重命名为某字段
type SelectItem struct {
Executor **Builder
FieldName string
}
//type SelectItem struct {
// Executor **Builder
// FieldName string
//}
// Builder 查询记录所需要的条件
type Builder struct {
//数据库操作连接
LinkCommon model.LinkCommon
table interface{}
tableAlias string
//查询参数
tableName string
selectList []string
selectList []SelectItem
selectExpList []*SelectItem
groupList []string
whereList []WhereItem
joinList []string
joinList []JoinItem
havingList []WhereItem
orderList []string
offset int
@@ -64,11 +67,11 @@ type Builder struct {
driverName string
}
type WhereItem struct {
Field string
Opt string
Val any
}
//type WhereItem struct {
// Field string
// Opt string
// Val any
//}
func (ex *Builder) Driver(driverName string) *Builder {
ex.driverName = driverName
@@ -313,17 +316,18 @@ func (ex *Builder) GetSqlAndParams() (string, []interface{}) {
}
var paramList []interface{}
tableName := getTableNameByTable(ex.table)
fieldStr, paramList := handleField(ex.selectList, ex.selectExpList, paramList)
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
joinStr := handleJoin(ex.joinList)
joinStr, paramList := ex.handleJoin(paramList)
groupStr := handleGroup(ex.groupList)
havingStr, paramList := ex.handleHaving(ex.havingList, paramList)
orderStr := handleOrder(ex.orderList)
limitStr, paramList := ex.handleLimit(ex.offset, ex.pageSize, paramList)
lockStr := handleLockForUpdate(ex.isLockForUpdate)
sqlStr := "SELECT " + fieldStr + " FROM " + ex.tableName + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
sqlStr := "SELECT " + fieldStr + " FROM " + tableName + " " + ex.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
@@ -377,7 +381,10 @@ func (ex *Builder) Truncate() (int64, error) {
// Exists 存在某记录
func (ex *Builder) Exists() (bool, error) {
var obj IntStruct
err := ex.Select("1 as c").Limit(0, 1).GetOne(&obj)
ex.selectCommon("", "1 as c", nil)
err := ex.Limit(0, 1).GetOne(&obj)
if err != nil {
return false, err
}
@@ -551,9 +558,12 @@ func (ex *Builder) Debug(isDebug bool) *Builder {
}
// Table 链式操作-从哪个表查询,允许直接写别名,例如 person p
func (ex *Builder) Table(tableName string) *Builder {
ex.tableName = tableName
return ex
func (b *Builder) Table(table interface{}, alias ...string) *Builder {
b.table = table
if len(alias) > 0 {
b.tableAlias = alias[0]
}
return b
}
// GroupBy 链式操作,以某字段进行分组
@@ -563,8 +573,8 @@ func (ex *Builder) GroupBy(fieldName string) *Builder {
}
// OrderBy 链式操作,以某字段进行排序
func (ex *Builder) OrderBy(field string, orderType string) *Builder {
ex.orderList = append(ex.orderList, field+" "+orderType)
func (ex *Builder) OrderBy(field interface{}, orderType string) *Builder {
//ex.orderList = append(ex.orderList, field+" "+orderType)
return ex
}
@@ -592,12 +602,15 @@ func (ex *Builder) LockForUpdate(isLockForUpdate bool) *Builder {
func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) {
var whereList []string
for i := 0; i < len(where); i++ {
prefix := where[i].Prefix
fieldName := getFieldName(where[i].Field)
if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() {
executor := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer()))
subSql, subParams := executor.GetSqlAndParams()
if where[i].Opt != Raw {
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"("+subSql+")")
whereList = append(whereList, prefix+"."+fieldName+" "+where[i].Opt+" "+"("+subSql+")")
paramList = append(paramList, subParams...)
} else {
@@ -605,15 +618,15 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string,
} else {
if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le {
if ex.driverName == model.Sqlite3 {
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?")
whereList = append(whereList, prefix+"."+fieldName+" "+where[i].Opt+" "+"?")
} else {
switch where[i].Val.(type) {
case float32:
whereList = append(whereList, ex.getConcatForFloat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
whereList = append(whereList, ex.getConcatForFloat(prefix+"."+fieldName, "''")+" "+where[i].Opt+" "+"?")
case float64:
whereList = append(whereList, ex.getConcatForFloat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
whereList = append(whereList, ex.getConcatForFloat(prefix+"."+fieldName, "''")+" "+where[i].Opt+" "+"?")
default:
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?")
whereList = append(whereList, prefix+"."+fieldName+" "+where[i].Opt+" "+"?")
}
}
@@ -622,7 +635,7 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string,
if where[i].Opt == Between || where[i].Opt == NotBetween {
values := toAnyArr(where[i].Val)
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"(?) AND (?)")
whereList = append(whereList, prefix+"."+fieldName+" "+where[i].Opt+" "+"(?) AND (?)")
paramList = append(paramList, values...)
}
@@ -640,7 +653,7 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string,
}
}
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+ex.getConcatForLike(valueStr...))
whereList = append(whereList, prefix+"."+fieldName+" "+where[i].Opt+" "+ex.getConcatForLike(valueStr...))
}
if where[i].Opt == In || where[i].Opt == NotIn {
@@ -650,12 +663,12 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string,
placeholder = append(placeholder, "?")
}
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"("+strings.Join(placeholder, ",")+")")
whereList = append(whereList, prefix+"."+fieldName+" "+where[i].Opt+" "+"("+strings.Join(placeholder, ",")+")")
paramList = append(paramList, values...)
}
if where[i].Opt == Raw {
whereList = append(whereList, where[i].Field+fmt.Sprintf("%v", where[i].Val))
whereList = append(whereList, prefix+"."+fieldName+fmt.Sprintf("%v", where[i].Val))
}
}
}