diff --git a/crud.go b/crud.go index 3f2c517..b7b2672 100644 --- a/crud.go +++ b/crud.go @@ -105,6 +105,86 @@ func (db *Executor) GetMany(values interface{}) error { return nil } +// GetManyNew 查询记录(新) +func (db *Executor) GetManyNew(values interface{}) error { + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + + var paramList []any + fieldStr := handleField(db.SelectList) + whereStr, paramList := handleWhere(db.WhereList, paramList) + joinStr := handleJoin(db.JoinList) + groupStr := handleGroup(db.GroupList) + havingStr, paramList := handleHaving(db.HavingList, paramList) + orderStr := handleOrder(db.OrderList) + limitStr, paramList := handleLimit(db.Offset, db.PageSize, paramList) + lockStr := handleLockForUpdate(db.IsLockForUpdate) + + sqlStr := "SELECT " + fieldStr + " FROM " + db.TableName + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr + + if db.IsDebug { + fmt.Println(sqlStr) + fmt.Println(paramList...) + } + + smt, err1 := db.LinkCommon.Prepare(sqlStr) + if err1 != nil { + fmt.Println(err1) + } + defer smt.Close() + + rows, err2 := smt.Query(paramList...) + if err2 != nil { + fmt.Println(err2) + } + defer rows.Close() + + fieldsTypes, _ := rows.ColumnTypes() + fields, _ := rows.Columns() + + for rows.Next() { + data := make(map[string]interface{}) + + scans := make([]interface{}, len(fields)) + for i := range scans { + scans[i] = &scans[i] + } + err := rows.Scan(scans...) + if err != nil { + return err + } + + for i, v := range scans { + if v == nil { + data[fields[i]] = v + } else { + if fieldsTypes[i].DatabaseTypeName() == "VARCHAR" || fieldsTypes[i].DatabaseTypeName() == "TEXT" || fieldsTypes[i].DatabaseTypeName() == "CHAR" || fieldsTypes[i].DatabaseTypeName() == "LONGTEXT" { + data[fields[i]] = fmt.Sprintf("%s", v) + } else if fieldsTypes[i].DatabaseTypeName() == "INT" || fieldsTypes[i].DatabaseTypeName() == "BIGINT" || fieldsTypes[i].DatabaseTypeName() == "UNSIGNED INT" || fieldsTypes[i].DatabaseTypeName() == "UNSIGNED BIGINT" { + data[fields[i]] = fmt.Sprintf("%v", v) + } else if fieldsTypes[i].DatabaseTypeName() == "DECIMAL" { + data[fields[i]] = string(v.([]uint8)) + } else { + data[fields[i]] = v + } + } + } + + dest := reflect.New(destType).Elem() + for k, v := range data { + fieldName := CamelString(k) + if dest.FieldByName(fieldName).CanSet() { + filedType := dest.FieldByName(fieldName).Type().String() + x := transToNullType(v, filedType) + dest.FieldByName(fieldName).Set(x) + } + } + destSlice.Set(reflect.Append(destSlice, dest)) + } + + return nil +} + // GetOne 查询某一条记录 func (db *Executor) GetOne(obj interface{}) error { diff --git a/test/aorm_test.go b/test/aorm_test.go index d45d0cb..865e3bc 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -52,6 +52,7 @@ func TestAll(t *testing.T) { testGetOne(db, id) testGetMany(db) + testGetManyNew(db) testUpdate(db, id) testDelete(db, id) @@ -171,7 +172,20 @@ func testGetMany(db *sql.DB) { fmt.Println(errSelect) } for i := 0; i < len(list); i++ { - fmt.Println(list[i]) + //fmt.Println(list[i]) + } +} + +func testGetManyNew(db *sql.DB) { + fmt.Println("--- testGetManyNew ---") + + var list []Person + errSelect := aorm.Use(db).Debug(true).Where(&Person{Type: aorm.IntFrom(0)}).GetManyNew(&list) + if errSelect != nil { + fmt.Println(errSelect) + } + for i := 0; i < len(list); i++ { + //fmt.Println(list[i]) } }