From bc96e8e9b379e80159f56cd6a923b0430298f16b Mon Sep 17 00:00:00 2001 From: tangpanqing Date: Fri, 16 Dec 2022 18:54:10 +0800 Subject: [PATCH] add method SelectCount .... --- aorm.go | 6 +++++ crud.go | 62 +++++++++++++++++++++++++++++++++++++++++++---- test/aorm_test.go | 4 +++ 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/aorm.go b/aorm.go index eec9aae..73ba2ba 100644 --- a/aorm.go +++ b/aorm.go @@ -17,6 +17,7 @@ type Executor struct { LinkCommon LinkCommon TableName string SelectList []string + SelectExpList []ExpItem GroupList []string WhereList []WhereItem JoinList []string @@ -29,6 +30,11 @@ type Executor struct { OpinionList []OpinionItem } +type ExpItem struct { + Executor *Executor + FieldName string +} + // Use 使用数据库连接,或者事务 func Use(linkCommon LinkCommon) *Executor { executor := &Executor{ diff --git a/crud.go b/crud.go index 13ccfe3..14de836 100644 --- a/crud.go +++ b/crud.go @@ -25,6 +25,8 @@ const NotLike = "NOT LIKE" const Between = "BETWEEN" const NotBetween = "NOT BETWEEN" +const Raw = "Raw" + type WhereItem struct { Field string Opt string @@ -217,7 +219,7 @@ func (db *Executor) GetOne(obj interface{}) error { func (db *Executor) getSqlAndParams() (string, []any) { var paramList []any - fieldStr := handleField(db.SelectList) + fieldStr, paramList := handleField(db.SelectList, db.SelectExpList, paramList) whereStr, paramList := handleWhere(db.WhereList, paramList) joinStr := handleJoin(db.JoinList) groupStr := handleGroup(db.GroupList) @@ -548,6 +550,45 @@ func (db *Executor) Select(f string) *Executor { return db } +// SelectCount 链式操作-count(field) as field_new +func (db *Executor) SelectCount(f string, fieldNew string) *Executor { + db.SelectList = append(db.SelectList, "count("+f+") AS "+fieldNew) + return db +} + +// SelectSum 链式操作-sum(field) as field_new +func (db *Executor) SelectSum(f string, fieldNew string) *Executor { + db.SelectList = append(db.SelectList, "sum("+f+") AS "+fieldNew) + return db +} + +// SelectMin 链式操作-min(field) as field_new +func (db *Executor) SelectMin(f string, fieldNew string) *Executor { + db.SelectList = append(db.SelectList, "min("+f+") AS "+fieldNew) + return db +} + +// SelectMax 链式操作-max(field) as field_new +func (db *Executor) SelectMax(f string, fieldNew string) *Executor { + db.SelectList = append(db.SelectList, "max("+f+") AS "+fieldNew) + return db +} + +// SelectAvg 链式操作-avg(field) as field_new +func (db *Executor) SelectAvg(f string, fieldNew string) *Executor { + db.SelectList = append(db.SelectList, "avg("+f+") AS "+fieldNew) + return db +} + +// SelectExp 链式操作-表达式 +func (db *Executor) SelectExp(db2 *Executor, fieldNew string) *Executor { + db.SelectExpList = append(db.SelectExpList, ExpItem{ + Executor: db2, + FieldName: fieldNew, + }) + return db +} + // Table 链式操作-从哪个表查询,允许直接写别名,例如 person p func (db *Executor) Table(tableName string) *Executor { db.TableName = tableName @@ -661,12 +702,19 @@ func (db *Executor) LockForUpdate(isLockForUpdate bool) *Executor { } //拼接SQL,字段相关 -func handleField(selectList []string) string { - if len(selectList) == 0 { - return "*" +func handleField(selectList []string, selectExpList []ExpItem, paramList []any) (string, []any) { + if len(selectList) == 0 && len(selectExpList) == 0 { + return "*", paramList } - return strings.Join(selectList, ",") + //处理子语句 + for i := 0; i < len(selectExpList); i++ { + subSql, subParamList := selectExpList[i].Executor.getSqlAndParams() + selectList = append(selectList, "("+subSql+") AS "+selectExpList[i].FieldName) + paramList = append(paramList, subParamList...) + } + + return strings.Join(selectList, ","), paramList } //拼接SQL,查询条件 @@ -816,6 +864,10 @@ func whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) { whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"("+strings.Join(placeholder, ",")+")") paramList = append(paramList, values...) } + + //if where[i].Opt == Raw { + // whereList = append(whereList, where[i].Field+" "+fmt.Sprintf("%v", where[i].Val)) + //} } return whereList, paramList diff --git a/test/aorm_test.go b/test/aorm_test.go index 9871835..850427f 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -59,6 +59,7 @@ func TestAll(t *testing.T) { id2 := testInsert(db) testTable(db) testSelect(db) + return testWhere(db) testJoin(db) testGroupBy(db) @@ -240,6 +241,9 @@ func testSelect(db *sql.DB) { var listByFiled []Person aorm.Use(db).Debug(true).Select("name,age").Where(&Person{Age: aorm.IntFrom(18)}).GetMany(&listByFiled) + + sub := aorm.Use(db).Table("test_table").SelectCount("test_name", "test_name_count") + aorm.Use(db).Debug(true).SelectExp(sub, "test_name_count_new").Select("name,age").Where(&Person{Age: aorm.IntFrom(18)}).GetMany(&listByFiled) } func testWhere(db *sql.DB) {