diff --git a/builder/aggregation.go b/builder/aggregation.go index 0c5c389..bdd5c82 100644 --- a/builder/aggregation.go +++ b/builder/aggregation.go @@ -11,9 +11,9 @@ type FloatStruct struct { } // Count 聚合函数-数量 -func (ex *Builder) Count(fieldName string) (int64, error) { +func (b *Builder) Count(fieldName string) (int64, error) { var obj []IntStruct - err := ex.SelectCount(fieldName, "c", "").GetMany(&obj) + err := b.SelectCount(fieldName, "c", "").GetMany(&obj) if err != nil { return 0, err } @@ -22,9 +22,9 @@ func (ex *Builder) Count(fieldName string) (int64, error) { } // Sum 聚合函数-合计 -func (ex *Builder) Sum(fieldName interface{}) (float64, error) { +func (b *Builder) Sum(fieldName interface{}) (float64, error) { var obj []FloatStruct - err := ex.SelectSum(fieldName, "c").GetMany(&obj) + err := b.SelectSum(fieldName, "c").GetMany(&obj) if err != nil { return 0, err } @@ -33,9 +33,9 @@ func (ex *Builder) Sum(fieldName interface{}) (float64, error) { } // Avg 聚合函数-平均值 -func (ex *Builder) Avg(fieldName interface{}) (float64, error) { +func (b *Builder) Avg(fieldName interface{}) (float64, error) { var obj []FloatStruct - err := ex.SelectAvg(fieldName, "c").GetMany(&obj) + err := b.SelectAvg(fieldName, "c").GetMany(&obj) if err != nil { return 0, err } @@ -44,9 +44,9 @@ func (ex *Builder) Avg(fieldName interface{}) (float64, error) { } // Max 聚合函数-最大值 -func (ex *Builder) Max(fieldName interface{}) (float64, error) { +func (b *Builder) Max(fieldName interface{}) (float64, error) { var obj []FloatStruct - err := ex.SelectMax(fieldName, "c").GetMany(&obj) + err := b.SelectMax(fieldName, "c").GetMany(&obj) if err != nil { return 0, err } @@ -55,9 +55,9 @@ func (ex *Builder) Max(fieldName interface{}) (float64, error) { } // Min 聚合函数-最小值 -func (ex *Builder) Min(fieldName interface{}) (float64, error) { +func (b *Builder) Min(fieldName interface{}) (float64, error) { var obj []FloatStruct - err := ex.SelectMin(fieldName, "c").GetMany(&obj) + err := b.SelectMin(fieldName, "c").GetMany(&obj) if err != nil { return 0, err } diff --git a/builder/builder.go b/builder/builder.go index 3632c9a..2609897 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -7,7 +7,6 @@ import ( "unicode" ) -var RawEq = "rawEq" var TableMap = make(map[uintptr]string) var FieldMap = make(map[uintptr]FieldInfo) @@ -22,11 +21,10 @@ type GroupItem struct { } type WhereItem struct { - FuncName string - Prefix string - Field interface{} - Opt string - Val interface{} + Prefix string + Field interface{} + Opt string + Val interface{} } type SelectItem struct { @@ -185,9 +183,15 @@ func getTableNameByTable(table interface{}) string { if table == nil { panic("当前table不能是nil") } - tableName := TableMap[reflect.ValueOf(table).Pointer()] - strArr := strings.Split(tableName, ".") - return UnderLine(strArr[len(strArr)-1]) + + valueOf := reflect.ValueOf(table) + if reflect.Ptr == valueOf.Kind() { + tableName := TableMap[valueOf.Pointer()] + strArr := strings.Split(tableName, ".") + return UnderLine(strArr[len(strArr)-1]) + } else { + return fmt.Sprintf("%v", table) + } } func getTableNameByField(field interface{}) string { @@ -244,7 +248,7 @@ func getWhereStr(whereList []WhereItem, paramList []interface{}) (string, []inte paramList = append(paramList, whereList[i].Val) } - if whereList[i].Opt == "rawEq" { + if whereList[i].Opt == RawEq { value := getFieldName(whereList[i].Val) sqlList = append(sqlList, prefix+"."+field+"="+value) } diff --git a/builder/crud.go b/builder/crud.go index 0f2d437..d1bb768 100644 --- a/builder/crud.go +++ b/builder/crud.go @@ -32,12 +32,13 @@ const Between = "BETWEEN" const NotBetween = "NOT BETWEEN" const Raw = "Raw" +const RawEq = "RawEq" -// SelectItem 将某子语句重命名为某字段 -//type SelectItem struct { -// Executor **Builder -// FieldName string -//} +// SelectExpItem 将某子语句重命名为某字段 +type SelectExpItem struct { + Executor **Builder + FieldName interface{} +} // Builder 查询记录所需要的条件 type Builder struct { @@ -50,7 +51,7 @@ type Builder struct { //查询参数 tableName string selectList []SelectItem - selectExpList []*SelectItem + selectExpList []*SelectExpItem groupList []GroupItem whereList []WhereItem joinList []JoinItem @@ -58,6 +59,7 @@ type Builder struct { orderList []OrderItem offset int pageSize int + distinct bool isDebug bool isLockForUpdate bool @@ -69,27 +71,29 @@ type Builder struct { driverName string } -//type WhereItem struct { -// Field string -// Opt string -// Val any -//} +func (b *Builder) Distinct(distinct bool) *Builder { + b.distinct = distinct + return b +} -func (ex *Builder) Driver(driverName string) *Builder { - ex.driverName = driverName - return ex +func (b *Builder) Driver(driverName string) *Builder { + b.driverName = driverName + return b +} + +func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value) string { + if b.table != nil { + return getTableNameByTable(b.table) + } + + return getTableName(typeOf, valueOf) } // Insert 增加记录 -func (ex *Builder) Insert(dest interface{}) (int64, error) { +func (b *Builder) Insert(dest interface{}) (int64, error) { typeOf := reflect.TypeOf(dest) valueOf := reflect.ValueOf(dest) - //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = getTableName(typeOf, valueOf) - } - var keys []string var paramList []any var place []string @@ -104,24 +108,24 @@ func (ex *Builder) Insert(dest interface{}) (int64, error) { } } - sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" + sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf) + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" - if ex.driverName == model.Postgres { + if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } - if ex.driverName == model.Mssql { - return ex.insertForMssqlOrPostgres(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...) - } else if ex.driverName == model.Postgres { - return ex.insertForMssqlOrPostgres(sqlStr+" returning id", paramList...) + if b.driverName == model.Mssql { + return b.insertForMssqlOrPostgres(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...) + } else if b.driverName == model.Postgres { + return b.insertForMssqlOrPostgres(sqlStr+" returning id", paramList...) } else { - return ex.insertForCommon(sqlStr, paramList...) + return b.insertForCommon(sqlStr, paramList...) } } //对于Mssql,Postgres类型数据库,为了获取最后插入的id,需要改写入为查询 -func (ex *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, error) { - rows, err := ex.LinkCommon.Query(sql, paramList...) +func (b *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, error) { + rows, err := b.LinkCommon.Query(sql, paramList...) if err != nil { return 0, err } @@ -134,8 +138,8 @@ func (ex *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64 } //对于非Mssql,Postgres类型数据库,可以直接获取最后插入的id -func (ex *Builder) insertForCommon(sql string, paramList ...any) (int64, error) { - res, err := ex.Exec(sql, paramList...) +func (b *Builder) insertForCommon(sql string, paramList ...any) (int64, error) { + res, err := b.Exec(sql, paramList...) if err != nil { return 0, err } @@ -163,7 +167,7 @@ func convertToPostgresSql(sqlStr string) string { } // InsertBatch 批量增加记录 -func (ex *Builder) InsertBatch(values interface{}) (int64, error) { +func (b *Builder) InsertBatch(values interface{}) (int64, error) { var keys []string var paramList []any @@ -175,11 +179,6 @@ func (ex *Builder) InsertBatch(values interface{}) (int64, error) { } typeOf := reflect.TypeOf(values).Elem().Elem() - //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = getTableName(typeOf, valueOf.Index(0)) - } - for j := 0; j < valueOf.Len(); j++ { var placeItem []string @@ -200,13 +199,13 @@ func (ex *Builder) InsertBatch(values interface{}) (int64, error) { place = append(place, "("+strings.Join(placeItem, ",")+")") } - sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") + sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf.Index(0)) + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") - if ex.driverName == model.Postgres { + if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } - res, err := ex.Exec(sqlStr, paramList...) + res, err := b.Exec(sqlStr, paramList...) if err != nil { return 0, err } @@ -220,10 +219,10 @@ func (ex *Builder) InsertBatch(values interface{}) (int64, error) { } // GetRows 获取行操作 -func (ex *Builder) GetRows() (*sql.Rows, error) { - sqlStr, paramList := ex.GetSqlAndParams() +func (b *Builder) GetRows() (*sql.Rows, error) { + sqlStr, paramList := b.GetSqlAndParams() - smt, errSmt := ex.LinkCommon.Prepare(sqlStr) + smt, errSmt := b.LinkCommon.Prepare(sqlStr) if errSmt != nil { return nil, errSmt } @@ -238,8 +237,8 @@ func (ex *Builder) GetRows() (*sql.Rows, error) { } // GetMany 查询记录(新) -func (ex *Builder) GetMany(values interface{}) error { - rows, errRows := ex.GetRows() +func (b *Builder) GetMany(values interface{}) error { + rows, errRows := b.GetRows() defer rows.Close() if errRows != nil { return errRows @@ -273,10 +272,10 @@ func (ex *Builder) GetMany(values interface{}) error { } // GetOne 查询某一条记录 -func (ex *Builder) GetOne(obj interface{}) error { - ex.Limit(0, 1) +func (b *Builder) GetOne(obj interface{}) error { + b.Limit(0, 1) - rows, errRows := ex.GetRows() + rows, errRows := b.GetRows() defer rows.Close() if errRows != nil { return errRows @@ -306,35 +305,35 @@ func (ex *Builder) GetOne(obj interface{}) error { } // RawSql 执行原始的sql语句 -func (ex *Builder) RawSql(sql string, paramList ...interface{}) *Builder { - ex.sql = sql - ex.paramList = paramList - return ex +func (b *Builder) RawSql(sql string, paramList ...interface{}) *Builder { + b.sql = sql + b.paramList = paramList + return b } -func (ex *Builder) GetSqlAndParams() (string, []interface{}) { - if ex.sql != "" { - return ex.sql, ex.paramList +func (b *Builder) GetSqlAndParams() (string, []interface{}) { + if b.sql != "" { + return b.sql, b.paramList } var paramList []interface{} - tableName := getTableNameByTable(ex.table) - fieldStr, paramList := ex.handleField(paramList) - whereStr, paramList := ex.handleWhere(paramList) - joinStr, paramList := ex.handleJoin(paramList) - groupStr, paramList := ex.handleGroup(paramList) - havingStr, paramList := ex.handleHaving(paramList) - orderStr, paramList := ex.handleOrder(paramList) - limitStr, paramList := ex.handleLimit(ex.offset, ex.pageSize, paramList) - lockStr := handleLockForUpdate(ex.isLockForUpdate) + tableName := getTableNameByTable(b.table) + fieldStr, paramList := b.handleField(paramList) + whereStr, paramList := b.handleWhere(paramList) + joinStr, paramList := b.handleJoin(paramList) + groupStr, paramList := b.handleGroup(paramList) + havingStr, paramList := b.handleHaving(paramList) + orderStr, paramList := b.handleOrder(paramList) + limitStr, paramList := b.handleLimit(b.offset, b.pageSize, paramList) + lockStr := handleLockForUpdate(b.isLockForUpdate) - sqlStr := "SELECT " + fieldStr + " FROM " + tableName + " " + ex.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr + sqlStr := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr - if ex.driverName == model.Postgres { + if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } - if ex.isDebug { + if b.isDebug { fmt.Println(sqlStr) //fmt.Println(paramList...) } @@ -343,47 +342,64 @@ func (ex *Builder) GetSqlAndParams() (string, []interface{}) { } // Update 更新记录 -func (ex *Builder) Update(dest interface{}) (int64, error) { - var paramList []any - setStr, paramList := ex.handleSet(dest, paramList) - whereStr, paramList := ex.handleWhere(paramList) - sqlStr := "UPDATE " + ex.tableName + setStr + whereStr +func (b *Builder) Update(dest interface{}) (int64, error) { + typeOf := reflect.TypeOf(dest) + valueOf := reflect.ValueOf(dest) - if ex.driverName == model.Postgres { + var paramList []any + setStr, paramList := b.handleSet(typeOf, valueOf, paramList) + whereStr, paramList := b.handleWhere(paramList) + sqlStr := "UPDATE " + b.getTableNameCommon(typeOf, valueOf) + setStr + whereStr + + if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } - return ex.ExecAffected(sqlStr, paramList...) + return b.ExecAffected(sqlStr, paramList...) } // Delete 删除记录 -func (ex *Builder) Delete() (int64, error) { - var paramList []any - whereStr, paramList := ex.handleWhere(paramList) - sqlStr := "DELETE FROM " + getTableNameByTable(ex.table) + whereStr +func (b *Builder) Delete(destList ...interface{}) (int64, error) { + tableName := "" - if ex.driverName == model.Postgres { + if len(destList) > 0 { + b.Where(destList[0]) + + typeOf := reflect.TypeOf(destList[0]) + valueOf := reflect.ValueOf(destList[0]) + tableName = b.getTableNameCommon(typeOf, valueOf) + } + + if tableName == "" { + tableName = getTableNameByTable(b.table) + } + + var paramList []any + whereStr, paramList := b.handleWhere(paramList) + sqlStr := "DELETE FROM " + tableName + whereStr + + if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } - return ex.ExecAffected(sqlStr, paramList...) + return b.ExecAffected(sqlStr, paramList...) } // Truncate 清空记录, sqlite3不支持此操作 -func (ex *Builder) Truncate() (int64, error) { - sqlStr := "TRUNCATE TABLE " + ex.tableName - if ex.driverName == model.Sqlite3 { - sqlStr = "DELETE FROM " + ex.tableName +func (b *Builder) Truncate() (int64, error) { + sqlStr := "TRUNCATE TABLE " + getTableNameByTable(b.table) + if b.driverName == model.Sqlite3 { + sqlStr = "DELETE FROM " + getTableNameByTable(b.table) } - return ex.ExecAffected(sqlStr) + return b.ExecAffected(sqlStr) } // Exists 存在某记录 -func (ex *Builder) Exists() (bool, error) { +func (b *Builder) Exists() (bool, error) { var obj IntStruct - err := ex.selectCommon("", "1 as c", nil, "").Limit(0, 1).GetOne(&obj) + err := b.selectCommon("", "1 as c", nil, "").Limit(0, 1).GetOne(&obj) if err != nil { return false, err } @@ -396,18 +412,18 @@ func (ex *Builder) Exists() (bool, error) { } // DoesntExist 不存在某记录 -func (ex *Builder) DoesntExist() (bool, error) { - isE, err := ex.Exists() +func (b *Builder) DoesntExist() (bool, error) { + isE, err := b.Exists() return !isE, err } // Value 字段值 -func (ex *Builder) Value(field interface{}, dest interface{}) error { - ex.Select(field).Limit(0, 1) +func (b *Builder) Value(field interface{}, dest interface{}) error { + b.Select(field).Limit(0, 1) fieldName := getFieldName(field) - rows, errRows := ex.GetRows() + rows, errRows := b.GetRows() defer rows.Close() if errRows != nil { return errRows @@ -442,10 +458,10 @@ func (ex *Builder) Value(field interface{}, dest interface{}) error { } // Pluck 获取某一列的值 -func (ex *Builder) Pluck(fieldName interface{}, values interface{}) error { - ex.Select(fieldName) +func (b *Builder) Pluck(fieldName interface{}, values interface{}) error { + b.Select(fieldName) - rows, errRows := ex.GetRows() + rows, errRows := b.GetRows() defer rows.Close() if errRows != nil { return errRows @@ -484,45 +500,45 @@ func (ex *Builder) Pluck(fieldName interface{}, values interface{}) error { } // Increment 某字段自增 -func (ex *Builder) Increment(field interface{}, step int) (int64, error) { +func (b *Builder) Increment(field interface{}, step int) (int64, error) { var paramList []any paramList = append(paramList, step) - whereStr, paramList := ex.handleWhere(paramList) - sqlStr := "UPDATE " + getTableNameByTable(ex.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "+?" + whereStr + whereStr, paramList := b.handleWhere(paramList) + sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "+?" + whereStr - if ex.driverName == model.Postgres { + if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } - return ex.ExecAffected(sqlStr, paramList...) + return b.ExecAffected(sqlStr, paramList...) } // Decrement 某字段自减 -func (ex *Builder) Decrement(field interface{}, step int) (int64, error) { +func (b *Builder) Decrement(field interface{}, step int) (int64, error) { var paramList []any paramList = append(paramList, step) - whereStr, paramList := ex.handleWhere(paramList) - sqlStr := "UPDATE " + getTableNameByTable(ex.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "-?" + whereStr + whereStr, paramList := b.handleWhere(paramList) + sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "-?" + whereStr - if ex.driverName == model.Postgres { + if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } - return ex.ExecAffected(sqlStr, paramList...) + return b.ExecAffected(sqlStr, paramList...) } // Exec 通用执行-新增,更新,删除 -func (ex *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { - if ex.driverName == model.Postgres { +func (b *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { + if b.driverName == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } - if ex.isDebug { + if b.isDebug { fmt.Println(sqlStr) //fmt.Println(args...) } - smt, err1 := ex.LinkCommon.Prepare(sqlStr) + smt, err1 := b.LinkCommon.Prepare(sqlStr) if err1 != nil { return nil, err1 } @@ -533,13 +549,13 @@ func (ex *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) return nil, err2 } - //ex.clear() + //b.clear() return res, nil } // ExecAffected 通用执行-更新,删除 -func (ex *Builder) ExecAffected(sqlStr string, args ...interface{}) (int64, error) { - res, err := ex.Exec(sqlStr, args...) +func (b *Builder) ExecAffected(sqlStr string, args ...interface{}) (int64, error) { + res, err := b.Exec(sqlStr, args...) if err != nil { return 0, err } @@ -553,9 +569,9 @@ func (ex *Builder) ExecAffected(sqlStr string, args ...interface{}) (int64, erro } // Debug 链式操作-是否开启调试,打印sql -func (ex *Builder) Debug(isDebug bool) *Builder { - ex.isDebug = isDebug - return ex +func (b *Builder) Debug(isDebug bool) *Builder { + b.isDebug = isDebug + return b } // Table 链式操作-从哪个表查询,允许直接写别名,例如 person p @@ -568,12 +584,12 @@ func (b *Builder) Table(table interface{}, alias ...string) *Builder { } // GroupBy 链式操作,以某字段进行分组 -func (ex *Builder) GroupBy(field interface{}, prefix ...string) *Builder { - ex.groupList = append(ex.groupList, GroupItem{ +func (b *Builder) GroupBy(field interface{}, prefix ...string) *Builder { + b.groupList = append(b.groupList, GroupItem{ Prefix: getPrefixByField(field, prefix...), Field: field, }) - return ex + return b } // OrderBy 链式操作,以某字段进行排序 @@ -588,27 +604,27 @@ func (b *Builder) OrderBy(field interface{}, orderType string, prefix ...string) } // Limit 链式操作,分页 -func (ex *Builder) Limit(offset int, pageSize int) *Builder { - ex.offset = offset - ex.pageSize = pageSize - return ex +func (b *Builder) Limit(offset int, pageSize int) *Builder { + b.offset = offset + b.pageSize = pageSize + return b } // Page 链式操作,分页 -func (ex *Builder) Page(pageNum int, pageSize int) *Builder { - ex.offset = (pageNum - 1) * pageSize - ex.pageSize = pageSize - return ex +func (b *Builder) Page(pageNum int, pageSize int) *Builder { + b.offset = (pageNum - 1) * pageSize + b.pageSize = pageSize + return b } // LockForUpdate 加锁, sqlte3不支持此操作 -func (ex *Builder) LockForUpdate(isLockForUpdate bool) *Builder { - ex.isLockForUpdate = isLockForUpdate - return ex +func (b *Builder) LockForUpdate(isLockForUpdate bool) *Builder { + b.isLockForUpdate = isLockForUpdate + return b } //拼接SQL,查询与筛选通用操作 -func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) { +func (b *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) { var whereList []string for i := 0; i < len(where); i++ { allFieldName := "" @@ -616,9 +632,6 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, allFieldName += where[i].Prefix + "." } allFieldName += getFieldName(where[i].Field) - if where[i].FuncName != "" { - allFieldName = where[i].FuncName + "(" + allFieldName + ")" - } if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() { executor := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer())) @@ -627,19 +640,17 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, if where[i].Opt != Raw { whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"("+subSql+")") paramList = append(paramList, subParams...) - } else { - } } else { if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le { - if ex.driverName == model.Sqlite3 { + if b.driverName == model.Sqlite3 { whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"?") } else { switch where[i].Val.(type) { case float32: - whereList = append(whereList, ex.getConcatForFloat(allFieldName, "''")+" "+where[i].Opt+" "+"?") + whereList = append(whereList, b.getConcatForFloat(allFieldName, "''")+" "+where[i].Opt+" "+"?") case float64: - whereList = append(whereList, ex.getConcatForFloat(allFieldName, "''")+" "+where[i].Opt+" "+"?") + whereList = append(whereList, b.getConcatForFloat(allFieldName, "''")+" "+where[i].Opt+" "+"?") default: whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"?") } @@ -668,7 +679,7 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, } } - whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+ex.getConcatForLike(valueStr...)) + whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+b.getConcatForLike(valueStr...)) } if where[i].Opt == In || where[i].Opt == NotIn { @@ -685,6 +696,10 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, if where[i].Opt == Raw { whereList = append(whereList, allFieldName+fmt.Sprintf("%v", where[i].Val)) } + + if where[i].Opt == RawEq { + whereList = append(whereList, allFieldName+Eq+getPrefixByField(where[i].Val)+"."+getFieldName(where[i].Val)) + } } } return whereList, paramList @@ -758,18 +773,18 @@ func getScans(columnNameList []string, fieldNameMap map[string]int, destValue re return scans } -func (ex *Builder) getConcatForFloat(vars ...string) string { - if ex.driverName == model.Sqlite3 { +func (b *Builder) getConcatForFloat(vars ...string) string { + if b.driverName == model.Sqlite3 { return strings.Join(vars, "||") - } else if ex.driverName == model.Postgres { + } else if b.driverName == model.Postgres { return vars[0] } else { return "CONCAT(" + strings.Join(vars, ",") + ")" } } -func (ex *Builder) getConcatForLike(vars ...string) string { - if ex.driverName == model.Sqlite3 || ex.driverName == model.Postgres { +func (b *Builder) getConcatForLike(vars ...string) string { + if b.driverName == model.Sqlite3 || b.driverName == model.Postgres { return strings.Join(vars, "||") } else { return "CONCAT(" + strings.Join(vars, ",") + ")" diff --git a/builder/handle.go b/builder/handle.go index 0d0ae28..589947b 100644 --- a/builder/handle.go +++ b/builder/handle.go @@ -8,22 +8,22 @@ import ( ) //拼接SQL,字段相关 -func (ex *Builder) handleField(paramList []any) (string, []any) { - if len(ex.selectList) == 0 && len(ex.selectExpList) == 0 { - return "*", paramList +func (b *Builder) handleField(paramList []any) (string, []any) { + fieldStr := "" + if b.distinct { + fieldStr += "DISTINCT " + } + + if len(b.selectList) == 0 && len(b.selectExpList) == 0 { + fieldStr += "*" + return fieldStr, paramList } - //处理子语句 - //for i := 0; i < len(selectExpList); i++ { - // executor := *(selectExpList[i].Executor) - // subSql, subParamList := executor.GetSqlAndParams() - // selectList = append(selectList, "("+subSql+") AS "+selectExpList[i].FieldName) - // paramList = append(paramList, subParamList...) - //} var strList []string - for i := 0; i < len(ex.selectList); i++ { - selectItem := ex.selectList[i] + //处理一般的参数 + for i := 0; i < len(b.selectList); i++ { + selectItem := b.selectList[i] str := "" if selectItem.FuncName != "" { @@ -50,28 +50,35 @@ func (ex *Builder) handleField(paramList []any) (string, []any) { strList = append(strList, str) } - return strings.Join(strList, ","), paramList + //处理子语句 + for i := 0; i < len(b.selectExpList); i++ { + executor := *(b.selectExpList[i].Executor) + subSql, subParamList := executor.GetSqlAndParams() + strList = append(strList, "("+subSql+") AS "+getFieldName(b.selectExpList[i].FieldName)) + paramList = append(paramList, subParamList...) + } + + fieldStr += strings.Join(strList, ",") + return fieldStr, paramList } //拼接SQL,查询条件 -func (ex *Builder) handleWhere(paramList []any) (string, []any) { - if len(ex.whereList) == 0 { +func (b *Builder) handleWhere(paramList []any) (string, []any) { + if len(b.whereList) == 0 { return "", paramList } - strList, paramList := ex.whereAndHaving(ex.whereList, paramList) + strList, paramList := b.whereAndHaving(b.whereList, paramList) return " WHERE " + strings.Join(strList, " AND "), paramList } //拼接SQL,更新信息 -func (ex *Builder) handleSet(dest interface{}, paramList []any) (string, []any) { - typeOf := reflect.TypeOf(dest) - valueOf := reflect.ValueOf(dest) +func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramList []any) (string, []any) { //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = getTableName(typeOf, valueOf) + if b.tableName == "" { + b.tableName = getTableName(typeOf, valueOf) } var keys []string @@ -110,52 +117,52 @@ func (b *Builder) handleJoin(paramList []interface{}) (string, []interface{}) { } //拼接SQL,结果分组 -func (ex *Builder) handleGroup(paramList []any) (string, []any) { - if len(ex.groupList) == 0 { +func (b *Builder) handleGroup(paramList []any) (string, []any) { + if len(b.groupList) == 0 { return "", paramList } var groupList []string - for i := 0; i < len(ex.groupList); i++ { - groupList = append(groupList, ex.groupList[i].Prefix+"."+getFieldName(ex.groupList[i].Field)) + for i := 0; i < len(b.groupList); i++ { + groupList = append(groupList, b.groupList[i].Prefix+"."+getFieldName(b.groupList[i].Field)) } return " GROUP BY " + strings.Join(groupList, ","), paramList } //拼接SQL,结果筛选 -func (ex *Builder) handleHaving(paramList []any) (string, []any) { - if len(ex.havingList) == 0 { +func (b *Builder) handleHaving(paramList []any) (string, []any) { + if len(b.havingList) == 0 { return "", paramList } - strList, paramList := ex.whereAndHaving(ex.havingList, paramList) + strList, paramList := b.whereAndHaving(b.havingList, paramList) return " Having " + strings.Join(strList, " AND "), paramList } //拼接SQL,结果排序 -func (ex *Builder) handleOrder(paramList []any) (string, []any) { - if len(ex.orderList) == 0 { +func (b *Builder) handleOrder(paramList []any) (string, []any) { + if len(b.orderList) == 0 { return "", paramList } var orderList []string - for i := 0; i < len(ex.orderList); i++ { - orderList = append(orderList, ex.orderList[i].Prefix+"."+getFieldName(ex.orderList[i].Field)+" "+ex.orderList[i].OrderType) + for i := 0; i < len(b.orderList); i++ { + orderList = append(orderList, b.orderList[i].Prefix+"."+getFieldName(b.orderList[i].Field)+" "+b.orderList[i].OrderType) } return " ORDER BY " + strings.Join(orderList, ","), paramList } //拼接SQL,分页相关 Postgres数据库分页数量在前偏移在后,其他数据库偏移量在前分页数量在后,另外Mssql数据库的关键词是offset...next -func (ex *Builder) handleLimit(offset int, pageSize int, paramList []any) (string, []any) { +func (b *Builder) handleLimit(offset int, pageSize int, paramList []any) (string, []any) { if 0 == pageSize { return "", paramList } str := "" - if ex.driverName == model.Postgres { + if b.driverName == model.Postgres { paramList = append(paramList, pageSize) paramList = append(paramList, offset) @@ -165,7 +172,7 @@ func (ex *Builder) handleLimit(offset int, pageSize int, paramList []any) (strin paramList = append(paramList, pageSize) str = " Limit ?,? " - if ex.driverName == model.Mssql { + if b.driverName == model.Mssql { str = " offset ? rows fetch next ? rows only " } } diff --git a/builder/having.go b/builder/having.go index 09d9ac3..7dcbe0f 100644 --- a/builder/having.go +++ b/builder/having.go @@ -6,13 +6,13 @@ import ( ) // Having 链式操作,以对象作为筛选条件 -func (ex *Builder) Having(dest interface{}) *Builder { +func (b *Builder) Having(dest interface{}) *Builder { typeOf := reflect.TypeOf(dest) valueOf := reflect.ValueOf(dest) //如果没有设置表名 - if ex.tableName == "" { - ex.tableName = getTableName(typeOf, valueOf) + if b.tableName == "" { + b.tableName = getTableName(typeOf, valueOf) } for i := 0; i < typeOf.Elem().NumField(); i++ { @@ -20,158 +20,80 @@ func (ex *Builder) Having(dest interface{}) *Builder { if isNotNull { key := helper.UnderLine(typeOf.Elem().Field(i).Name) val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() - ex.havingList = append(ex.havingList, WhereItem{Field: key, Opt: Eq, Val: val}) + b.havingList = append(b.havingList, WhereItem{Field: key, Opt: Eq, Val: val}) } } - return ex + return b } // HavingArr 链式操作,以数组作为筛选条件 -func (ex *Builder) HavingArr(havingList []WhereItem) *Builder { - ex.havingList = append(ex.havingList, havingList...) - return ex +func (b *Builder) HavingArr(havingList []WhereItem) *Builder { + b.havingList = append(b.havingList, havingList...) + return b } -func (ex *Builder) HavingEq(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: Eq, - Val: val, - }) - return ex +func (b *Builder) HavingEq(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Eq, val}) + return b } -func (ex *Builder) HavingNe(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: Ne, - Val: val, - }) - return ex +func (b *Builder) HavingNe(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Ne, val}) + return b } -func (ex *Builder) HavingGt(field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: "", - Prefix: "", - Field: field, - Opt: Gt, - Val: val, - }) - return ex +func (b *Builder) HavingGt(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Gt, val}) + return b } -func (ex *Builder) HavingGe(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: Ge, - Val: val, - }) - return ex +func (b *Builder) HavingGe(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Ge, val}) + return b } -func (ex *Builder) HavingLt(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: Lt, - Val: val, - }) - return ex +func (b *Builder) HavingLt(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Lt, val}) + return b } -func (ex *Builder) HavingLe(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: Le, - Val: val, - }) - return ex +func (b *Builder) HavingLe(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Le, val}) + return b } -func (ex *Builder) HavingIn(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: In, - Val: val, - }) - return ex +func (b *Builder) HavingIn(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, In, val}) + return b } -func (ex *Builder) HavingNotIn(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: NotIn, - Val: val, - }) - return ex +func (b *Builder) HavingNotIn(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, NotIn, val}) + return b } -func (ex *Builder) HavingBetween(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: Between, - Val: val, - }) - return ex +func (b *Builder) HavingBetween(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Between, val}) + return b } -func (ex *Builder) HavingNotBetween(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: NotBetween, - Val: val, - }) - return ex +func (b *Builder) HavingNotBetween(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, NotBetween, val}) + return b } -func (ex *Builder) HavingLike(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: Like, - Val: val, - }) - return ex +func (b *Builder) HavingLike(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Like, val}) + return b } -func (ex *Builder) HavingNotLike(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: NotLike, - Val: val, - }) - return ex +func (b *Builder) HavingNotLike(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, NotLike, val}) + return b } -func (ex *Builder) HavingRaw(funcName string, field interface{}, val interface{}, prefix ...string) *Builder { - ex.havingList = append(ex.havingList, WhereItem{ - FuncName: funcName, - Prefix: getPrefixByField(field, prefix...), - Field: field, - Opt: Raw, - Val: val, - }) - return ex +func (b *Builder) HavingRaw(field interface{}, val interface{}, prefix ...string) *Builder { + b.havingList = append(b.havingList, WhereItem{"", field, Raw, val}) + return b } diff --git a/builder/select.go b/builder/select.go index b815ae6..067dd4f 100644 --- a/builder/select.go +++ b/builder/select.go @@ -50,10 +50,10 @@ func (b *Builder) selectCommon(funcName string, field interface{}, fieldNew inte } // SelectExp 链式操作-表达式 -func (b *Builder) SelectExp(dbSub **Builder, fieldName string) *Builder { - //ex.selectExpList = append(ex.selectExpList, &SelectItem{ - // Executor: dbSub, - // FieldName: fieldName, - //}) +func (b *Builder) SelectExp(dbSub **Builder, fieldName interface{}) *Builder { + b.selectExpList = append(b.selectExpList, &SelectExpItem{ + Executor: dbSub, + FieldName: fieldName, + }) return b } diff --git a/builder/where.go b/builder/where.go index 657a833..e9536c7 100644 --- a/builder/where.go +++ b/builder/where.go @@ -34,66 +34,71 @@ func (b *Builder) WhereArr(whereList []WhereItem) *Builder { } func (b *Builder) WhereEq(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Eq, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Eq, val}) return b } func (b *Builder) WhereNe(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Ne, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Ne, val}) return b } func (b *Builder) WhereGt(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Gt, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Gt, val}) return b } func (b *Builder) WhereGe(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Ge, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Ge, val}) return b } func (b *Builder) WhereLt(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Lt, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Lt, val}) return b } func (b *Builder) WhereLe(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Le, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Le, val}) return b } func (b *Builder) WhereIn(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, In, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, In, val}) return b } func (b *Builder) WhereNotIn(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, NotIn, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, NotIn, val}) return b } func (b *Builder) WhereBetween(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Between, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Between, val}) return b } func (b *Builder) WhereNotBetween(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, NotBetween, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, NotBetween, val}) return b } func (b *Builder) WhereLike(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Like, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Like, val}) return b } func (b *Builder) WhereNotLike(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, NotLike, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, NotLike, val}) return b } func (b *Builder) WhereRaw(field interface{}, val interface{}, prefix ...string) *Builder { - b.whereList = append(b.whereList, WhereItem{"", getPrefixByField(field, prefix...), field, Raw, val}) + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, Raw, val}) + return b +} + +func (b *Builder) WhereRawEq(field interface{}, val interface{}, prefix ...string) *Builder { + b.whereList = append(b.whereList, WhereItem{getPrefixByField(field, prefix...), field, RawEq, val}) return b } diff --git a/test/aorm_test.go b/test/aorm_test.go index 8c70241..8826548 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -62,11 +62,12 @@ var person = Person{} var article = Article{} var articleVO = ArticleVO{} var personAge = PersonAge{} +var personWithArticleCount = PersonWithArticleCount{} func TestAll(t *testing.T) { aorm.Store(&person, &article) aorm.Store(&articleVO) - aorm.Store(&personAge) + aorm.Store(&personAge, &personWithArticleCount) dbList := make([]aorm.DbContent, 0) //dbList = append(dbList, testSqlite3Connect()) @@ -77,16 +78,15 @@ func TestAll(t *testing.T) { for i := 0; i < len(dbList); i++ { dbItem := dbList[i] - //testMigrate(dbItem.DriverName, dbItem.DbLink) - - //testShowCreateTable(dbItem.DriverName, dbItem.DbLink) + testMigrate(dbItem.DriverName, dbItem.DbLink) + testShowCreateTable(dbItem.DriverName, dbItem.DbLink) id := testInsert(dbItem.DriverName, dbItem.DbLink) testInsertBatch(dbItem.DriverName, dbItem.DbLink) - testGetOne(dbItem.DriverName, dbItem.DbLink, id) testGetMany(dbItem.DriverName, dbItem.DbLink) testUpdate(dbItem.DriverName, dbItem.DbLink, id) + isExists := testExists(dbItem.DriverName, dbItem.DbLink, id) if isExists != true { panic("应该存在,但是数据库不存在") @@ -99,17 +99,16 @@ func TestAll(t *testing.T) { } id2 := testInsert(dbItem.DriverName, dbItem.DbLink) - //testTable(dbItem.DriverName, dbItem.DbLink) - //testSelect(dbItem.DriverName, dbItem.DbLink) - //testSelectWithSub(dbItem.DriverName, dbItem.DbLink) - //testWhereWithSub(dbItem.DriverName, dbItem.DbLink) - //testWhere(dbItem.DriverName, dbItem.DbLink) + testTable(dbItem.DriverName, dbItem.DbLink) + testSelect(dbItem.DriverName, dbItem.DbLink) + testSelectWithSub(dbItem.DriverName, dbItem.DbLink) + testWhereWithSub(dbItem.DriverName, dbItem.DbLink) + testWhere(dbItem.DriverName, dbItem.DbLink) testJoin(dbItem.DriverName, dbItem.DbLink) testJoinWithAlias(dbItem.DriverName, dbItem.DbLink) testGroupBy(dbItem.DriverName, dbItem.DbLink) testHaving(dbItem.DriverName, dbItem.DbLink) - return testOrderBy(dbItem.DriverName, dbItem.DbLink) testLimit(dbItem.DriverName, dbItem.DbLink) testLock(dbItem.DriverName, dbItem.DbLink, id2) @@ -125,12 +124,13 @@ func TestAll(t *testing.T) { testAvg(dbItem.DriverName, dbItem.DbLink) testMin(dbItem.DriverName, dbItem.DbLink) testMax(dbItem.DriverName, dbItem.DbLink) - // - //testExec(dbItem.DriverName, dbItem.DbLink) - // - //testTransaction(dbItem.DriverName, dbItem.DbLink) - //testTruncate(dbItem.DriverName, dbItem.DbLink) - //testHelper(dbItem.DriverName, dbItem.DbLink) + + testDistinct(dbItem.DriverName, dbItem.DbLink) + + testExec(dbItem.DriverName, dbItem.DbLink) + + testTransaction(dbItem.DriverName, dbItem.DbLink) + testTruncate(dbItem.DriverName, dbItem.DbLink) } } @@ -318,6 +318,13 @@ func testDelete(driver string, db *sql.DB, id int64) { if errDelete != nil { panic(driver + "testDelete" + "found err") } + + _, errDelete2 := aorm.Use(db).Driver(driver).Debug(true).Delete(&Person{ + Id: null.IntFrom(id), + }) + if errDelete2 != nil { + panic(driver + "testDelete" + "found err") + } } func testExists(driver string, db *sql.DB, id int64) bool { @@ -328,69 +335,69 @@ func testExists(driver string, db *sql.DB, id int64) bool { return exists } -// -//func testTable(driver string, db *sql.DB) { -// _, err := aorm.Use(db).Debug(true).Driver(driver).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) -// if err != nil { -// panic(driver + " testTable " + "found err:" + err.Error()) -// } -//} -// -//func testSelect(driver string, db *sql.DB) { -// var listByFiled []Person -// err := aorm.Use(db).Debug(true).Driver(driver).Select("name,age").Where(&Person{Age: null.IntFrom(18)}).GetMany(&listByFiled) -// if err != nil { -// panic(driver + " testSelect " + "found err:" + err.Error()) -// } -//} -// -//func testSelectWithSub(driver string, db *sql.DB) { -// var listByFiled []PersonWithArticleCount -// -// sub := aorm.Sub().Table("article").SelectCount("id", "article_count_tem").WhereRaw("person_id", "=person.id") -// err := aorm.Use(db).Debug(true). -// Driver(driver). -// SelectExp(&sub, "article_count"). -// Select("*"). -// Where(&Person{Age: null.IntFrom(18)}). -// GetMany(&listByFiled) -// -// if err != nil { -// panic(driver + " testSelectWithSub " + "found err:" + err.Error()) -// } -//} -// -//func testWhereWithSub(driver string, db *sql.DB) { -// var listByFiled []Person -// -// sub := aorm.Sub().Table("article").Select("person_id").GroupBy("person_id").HavingGt("count(person_id)", 0) -// -// err := aorm.Use(db).Debug(true). -// Table("person"). -// Driver(driver). -// WhereIn("id", &sub). -// GetMany(&listByFiled) -// -// if err != nil { -// panic(driver + " testWhereWithSub " + "found err:" + err.Error()) -// } -//} -// -//func testWhere(driver string, db *sql.DB) { -// var listByWhere []Person -// -// var where1 []builder.WhereItem -// where1 = append(where1, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) -// where1 = append(where1, builder.WhereItem{Field: "age", Opt: builder.In, Val: []int{18, 20}}) -// where1 = append(where1, builder.WhereItem{Field: "money", Opt: builder.Between, Val: []float64{100.1, 200.9}}) -// where1 = append(where1, builder.WhereItem{Field: "money", Opt: builder.Eq, Val: 100.15}) -// where1 = append(where1, builder.WhereItem{Field: "name", Opt: builder.Like, Val: []string{"%", "li", "%"}}) -// -// err := aorm.Use(db).Debug(true).Driver(driver).Table("person").WhereArr(where1).GetMany(&listByWhere) -// if err != nil { -// panic(driver + "testWhere" + "found err") -// } -//} +func testTable(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(true).Driver(driver).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) + if err != nil { + panic(driver + " testTable " + "found err:" + err.Error()) + } + + _, err2 := aorm.Use(db).Debug(true).Driver(driver).Table(&person).Insert(&Person{Name: null.StringFrom("Cherry")}) + if err2 != nil { + panic(driver + " testTable " + "found err:" + err2.Error()) + } +} + +func testSelect(driver string, db *sql.DB) { + var listByFiled []Person + err := aorm.Use(db).Debug(true).Driver(driver).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) + if err != nil { + panic(driver + " testSelect " + "found err:" + err.Error()) + } +} + +func testSelectWithSub(driver string, db *sql.DB) { + var listByFiled []PersonWithArticleCount + + sub := aorm.Sub().Table(&article).SelectCount(&article.Id, "article_count_tem").WhereRawEq(&article.PersonId, &person.Id) + err := aorm.Use(db).Debug(true). + Driver(driver). + SelectExp(&sub, &personWithArticleCount.ArticleCount). + SelectAll(&person). + Table(&person). + WhereEq(&person.Age, 18). + GetMany(&listByFiled) + + if err != nil { + panic(driver + " testSelectWithSub " + "found err:" + err.Error()) + } +} + +func testWhereWithSub(driver string, db *sql.DB) { + var listByFiled []Person + sub := aorm.Sub().Table(&article).SelectCount(&article.PersonId, "count_person_id").GroupBy(&article.PersonId).HavingGt("count_person_id", 0) + err := aorm.Use(db).Debug(true). + Table(&person). + Driver(driver). + WhereIn(&person.Id, &sub). + GetMany(&listByFiled) + if err != nil { + panic(driver + " testWhereWithSub " + "found err:" + err.Error()) + } +} + +func testWhere(driver string, db *sql.DB) { + var listByWhere []Person + err := aorm.Use(db).Debug(true).Driver(driver).Table(&person).WhereArr([]builder.WhereItem{ + {Field: &person.Type, Opt: builder.Eq, Val: 0}, + {Field: &person.Age, Opt: builder.In, Val: []int{18, 20}}, + {Field: &person.Money, Opt: builder.Between, Val: []float64{100.1, 200.9}}, + {Field: &person.Money, Opt: builder.Eq, Val: 100.15}, + {Field: &person.Name, Opt: builder.Like, Val: []string{"%", "li", "%"}}, + }).GetMany(&listByWhere) + if err != nil { + panic(driver + "testWhere" + "found err") + } +} func testJoin(driver string, db *sql.DB) { var list2 []ArticleVO @@ -601,7 +608,6 @@ func testPluck(driver string, db *sql.DB) { } } -// func testCount(driver string, db *sql.DB) { _, err := aorm.Use(db).Debug(true).Table(&person).WhereEq(&person.Age, 18).Driver(driver).Count("*") if err != nil { @@ -609,7 +615,6 @@ func testCount(driver string, db *sql.DB) { } } -// func testSum(driver string, db *sql.DB) { _, err := aorm.Use(db).Debug(true).Table(&person).WhereEq(&person.Age, 18).Driver(driver).Sum(&person.Age) if err != nil { @@ -638,82 +643,67 @@ func testMax(driver string, db *sql.DB) { } } -// -//func testExec(driver string, db *sql.DB) { -// _, err := aorm.Use(db).Debug(true).Driver(driver).Exec("UPDATE person SET name = ? WHERE id=?", "Bob", 3) -// if err != nil { -// panic(driver + "testExec" + "found err") -// } -//} -// -//func testTransaction(driver string, db *sql.DB) { -// tx, _ := db.Begin() -// -// id, errInsert := aorm.Use(tx).Debug(true).Driver(driver).Insert(&Person{ -// Name: null.StringFrom("Alice"), -// }) -// -// if errInsert != nil { -// tx.Rollback() -// panic(driver + " testTransaction " + "found err:" + errInsert.Error()) -// return -// } -// -// _, errCount := aorm.Use(tx).Debug(true).Driver(driver).Where(&Person{ -// Id: null.IntFrom(id), -// }).Count("*") -// if errCount != nil { -// tx.Rollback() -// panic(driver + "testTransaction" + "found err") -// return -// } -// -// var person Person -// errPerson := aorm.Use(tx).Debug(true).Where(&Person{ -// Id: null.IntFrom(id), -// }).Driver(driver).OrderBy("id", "DESC").GetOne(&person) -// if errPerson != nil { -// tx.Rollback() -// panic(driver + "testTransaction" + "found err") -// return -// } -// -// _, errUpdate := aorm.Use(tx).Debug(true).Driver(driver).Where(&Person{ -// Id: null.IntFrom(id), -// }).Update(&Person{ -// Name: null.StringFrom("Bob"), -// }) -// -// if errUpdate != nil { -// tx.Rollback() -// panic(driver + "testTransaction" + "found err") -// return -// } -// -// tx.Commit() -//} -// -//func testTruncate(driver string, db *sql.DB) { -// _, err := aorm.Use(db).Debug(true).Driver(driver).Table("person").Truncate() -// if err != nil { -// panic(driver + " testTruncate " + "found err") -// } -//} +func testDistinct(driver string, db *sql.DB) { + var listByFiled []Person + err := aorm.Use(db).Debug(true).Driver(driver).Distinct(true).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) + if err != nil { + panic(driver + " testSelect " + "found err:" + err.Error()) + } +} -//func testHelper(driver string, db *sql.DB) { -// var list2 []ArticleVO -// var where2 []builder.WhereItem -// where2 = append(where2, builder.WhereItem{Field: "o.type", Opt: builder.Eq, Val: 0}) -// where2 = append(where2, builder.WhereItem{Field: "p.age", Opt: builder.In, Val: []int{18, 20}}) -// err := aorm.Use(db).Debug(true). -// Table("article o"). -// LeftJoin("person p", helper.Ul("p.id=o.personId")). -// Select("o.*"). -// Select(helper.Ul("p.name as personName")). -// WhereArr(where2). -// Driver(driver). -// GetMany(&list2) -// if err != nil { -// panic(driver + "testHelper" + "found err") -// } -//} +func testExec(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(true).Driver(driver).Exec("UPDATE person SET name = ? WHERE person.id=?", "Bob", 3) + if err != nil { + panic(driver + "testExec" + "found err") + } +} + +func testTransaction(driver string, db *sql.DB) { + tx, _ := db.Begin() + + id, errInsert := aorm.Use(tx).Debug(true).Driver(driver).Insert(&Person{ + Name: null.StringFrom("Alice"), + }) + + if errInsert != nil { + tx.Rollback() + panic(driver + " testTransaction " + "found err:" + errInsert.Error()) + return + } + + _, errCount := aorm.Use(tx).Debug(true).Driver(driver).Table(&person).WhereEq(&person.Id, id).Count("*") + if errCount != nil { + tx.Rollback() + panic(driver + "testTransaction" + "found err") + return + } + + var personItem Person + errPerson := aorm.Use(tx).Debug(true).Driver(driver).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) + if errPerson != nil { + tx.Rollback() + panic(driver + "testTransaction" + "found err") + return + } + + _, errUpdate := aorm.Use(tx).Debug(true).Driver(driver).Where(&Person{ + Id: null.IntFrom(id), + }).Update(&Person{ + Name: null.StringFrom("Bob"), + }) + + if errUpdate != nil { + tx.Rollback() + panic(driver + "testTransaction" + "found err") + return + } + + tx.Commit() +} + +func testTruncate(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(true).Driver(driver).Table(&person).Truncate() + if err != nil { + panic(driver + " testTruncate " + "found err") + } +}