mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-10 02:50:19 +08:00
[RKNPU2] RKYOLO Support FP32 return value (#898)
* Make the RKNPU2 backend compatible with quantization for other models; formally remove the zp and scale quantization parameters from fd_tensor
* Update RKYOLO to return FP32 values
* Update rkyolov5 to support the FP32 format
* Update the YOLOv5 speed documentation

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
@@ -15,6 +15,7 @@ ONNX models cannot invoke the NPU on RK chips directly; the ONNX model must first be converted
 | Task Scenario | Model             | Model Version (tested versions) | ARM CPU/RKNN Speed (ms) |
 |---------------|-------------------|---------------------------------|-------------------------|
 | Detection     | Picodet           | Picodet-s                       | 162/112                 |
+| Detection     | RKYOLOV5          | YOLOV5-S-Relu(int8)             | -/57                    |
 | Segmentation  | Unet              | Unet-cityscapes                 | -/-                     |
 | Segmentation  | PP-LiteSeg        | PP_LiteSeg_T_STDC1_cityscapes   | -/-                     |
 | Segmentation  | PP-HumanSegV2Lite | portrait                        | 53/50                   |
@@ -25,12 +25,16 @@ void RKNPU2Infer(const std::string& model_file, const std::string& image_file) {
   auto im = cv::imread(image_file);

   fastdeploy::vision::DetectionResult res;
+  fastdeploy::TimeCounter tc;
+  tc.Start();
   if (!model.Predict(im, &res)) {
     std::cerr << "Failed to predict." << std::endl;
     return;
   }
-  std::cout << res.Str() << std::endl;
   auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
+  tc.End();
+  tc.PrintInfo("RKYOLOV5 in RKNN");
+  std::cout << res.Str() << std::endl;
   cv::imwrite("vis_result.jpg", vis_im);
   std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
 }
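The snippet above assumes a `model` object constructed earlier in the example. A minimal sketch of how such a model is typically built for the RKNPU2 backend (illustrative only; the exact RKYOLOV5 constructor arguments may differ, see the FastDeploy RKNPU2 examples):

#include <iostream>
#include <string>
#include "fastdeploy/vision.h"

// Sketch, not part of the diff: create an RKYOLOV5 model running on the
// Rockchip NPU. The .rknn file must be converted from ONNX beforehand.
void BuildAndCheckModel(const std::string& model_file) {
  fastdeploy::RuntimeOption option;
  option.UseRKNPU2();  // select the RKNPU2 backend
  auto model = fastdeploy::vision::detection::RKYOLOV5(
      model_file, option, fastdeploy::ModelFormat::RKNN);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize RKYOLOV5." << std::endl;
  }
}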
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.h"
-
+#include "fastdeploy/utils/perf.h"
 namespace fastdeploy {
 RKNPU2Backend::~RKNPU2Backend() {
   // Release memory uniformly here
@@ -190,7 +190,6 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
       FDERROR << "rknpu2_backend only supports NHWC or UNDEFINED input formats" << std::endl;
     }
-
     DumpTensorAttr(input_attrs_[i]);

     // copy input_attrs_ to input tensor info
     std::string temp_name = input_attrs_[i].name;
@@ -199,16 +198,13 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
     for (int j = 0; j < input_attrs_[i].n_dims; j++) {
       temp_shape[j] = (int)input_attrs_[i].dims[j];
     }
-    FDDataType temp_dtype =
-        fastdeploy::RKNPU2Backend::RknnTensorTypeToFDDataType(
-            input_attrs_[i].type);
+    FDDataType temp_dtype = fastdeploy::RKNPU2Backend::RknnTensorTypeToFDDataType(input_attrs_[i].type);
     TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
     inputs_desc_[i] = temp_input_info;
   }

   // Get detailed output parameters
-  output_attrs_ =
-      (rknn_tensor_attr*)malloc(sizeof(rknn_tensor_attr) * io_num.n_output);
+  output_attrs_ = (rknn_tensor_attr*)malloc(sizeof(rknn_tensor_attr) * io_num.n_output);
   memset(output_attrs_, 0, io_num.n_output * sizeof(rknn_tensor_attr));
   outputs_desc_.resize(io_num.n_output);
@@ -230,14 +226,8 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
     int n_dims = output_attrs_[i].n_dims;
-    if ((n_dims == 4) && (output_attrs_[i].dims[3] == 1)) {
-      n_dims--;
-      FDWARNING << "The output["
-                << i
-                << "].shape[3] is 1, remove this dim."
-                << std::endl;
-    }

     DumpTensorAttr(output_attrs_[i]);

     // copy output_attrs_ to output tensor
     std::string temp_name = output_attrs_[i].name;
     std::vector<int> temp_shape{};
@@ -246,9 +236,8 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
       temp_shape[j] = (int)output_attrs_[i].dims[j];
     }

-    FDDataType temp_dtype =
-        fastdeploy::RKNPU2Backend::RknnTensorTypeToFDDataType(
-            output_attrs_[i].type);
+    // The data type of output data is changed to FP32
+    FDDataType temp_dtype = FDDataType::FP32;
     TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
     outputs_desc_[i] = temp_input_info;
   }
@@ -265,11 +254,12 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
 void RKNPU2Backend::DumpTensorAttr(rknn_tensor_attr& attr) {
   printf("index=%d, name=%s, n_dims=%d, dims=[%d, %d, %d, %d], "
          "n_elems=%d, size=%d, fmt=%s, type=%s, "
-         "qnt_type=%s, zp=%d, scale=%f\n",
+         "qnt_type=%s, zp=%d, scale=%f, pass_through=%d",
          attr.index, attr.name, attr.n_dims, attr.dims[0], attr.dims[1],
          attr.dims[2], attr.dims[3], attr.n_elems, attr.size,
          get_format_string(attr.fmt), get_type_string(attr.type),
-         get_qnt_type_string(attr.qnt_type), attr.zp, attr.scale);
+         get_qnt_type_string(attr.qnt_type), attr.zp, attr.scale,
+         attr.pass_through);
 }

 TensorInfo RKNPU2Backend::GetInputInfo(int index) {
@@ -320,7 +310,12 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
     input_attrs_[i].type = input_type;
     input_attrs_[i].size = inputs[0].Nbytes();
     input_attrs_[i].size_with_stride = inputs[0].Nbytes();
+    input_attrs_[i].pass_through = 0;
+    if (input_attrs_[i].type == RKNN_TENSOR_FLOAT16 ||
+        input_attrs_[i].type == RKNN_TENSOR_FLOAT32) {
+      FDINFO << "The input model is not a quantized model. "
+                "Disabling the normalize operation." << std::endl;
+    }

     input_mems_[i] = rknn_create_mem(ctx, inputs[i].Nbytes());
     if (input_mems_[i] == nullptr) {
       FDERROR << "rknn_create_mem input_mems_ error." << std::endl;
@@ -345,11 +340,13 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
       FDERROR << "rknn_create_mem output_mems_ error." << std::endl;
       return false;
     }
+    if (output_attrs_[i].type == RKNN_TENSOR_FLOAT16) {
+      // The data type of output data is changed to FP32
+      output_attrs_[i].type = RKNN_TENSOR_FLOAT32;
+    }

     // the default output type depends on the model; float32 is required to compute top5
     ret = rknn_set_io_mem(ctx, output_mems_[i], &output_attrs_[i]);

     // set output memory and attribute
     if (ret != RKNN_SUCC) {
       FDERROR << "output tensor memory rknn_set_io_mem fail! ret=" << ret
@@ -377,7 +374,6 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
     }
   }
-

   // run rknn
   ret = rknn_run(ctx, nullptr);
   if (ret != RKNN_SUCC) {
@@ -395,8 +391,6 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
     }
     (*outputs)[i].Resize(temp_shape, outputs_desc_[i].dtype,
                          outputs_desc_[i].name);
-    std::vector<float> output_scale = {output_attrs_[i].scale};
-    (*outputs)[i].SetQuantizationInfo(output_attrs_[i].zp, output_scale);
     memcpy((*outputs)[i].MutableData(), (float*)output_mems_[i]->virt_addr,
            (*outputs)[i].Nbytes());
   }
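With the hunks above, the backend now resizes outputs to an FP32 dtype and no longer attaches zp/scale metadata, so downstream code reads results directly. A hypothetical consumer-side sketch (not part of the diff; `output` stands for any FDTensor filled by Infer):

#include <vector>
#include "fastdeploy/core/fd_tensor.h"

// Hypothetical sketch: with FP32 outputs there is no caller-side
// (value - zero_point) * scale dequantization step anymore.
std::vector<float> ReadOutput(const fastdeploy::FDTensor& output) {
  const float* data = static_cast<const float*>(output.Data());
  return std::vector<float>(data, data + output.Numel());
}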
@@ -138,11 +138,6 @@ void FDTensor::Resize(const std::vector<int64_t>& new_shape) {
   external_data_ptr = nullptr;
 }

-void FDTensor::SetQuantizationInfo(int32_t zero_point, std::vector<float>& scale) {
-  quantized_parameter_.first = zero_point;
-  quantized_parameter_.second = scale;
-}
-
 void FDTensor::Resize(const std::vector<int64_t>& new_shape,
                       const FDDataType& data_type,
                       const std::string& tensor_name,
@@ -455,9 +450,4 @@ FDTensor& FDTensor::operator=(FDTensor&& other) {
   return *this;
 }

-const std::pair<int32_t, std::vector<float>>
-FDTensor::GetQuantizationInfo() const {
-  return quantized_parameter_;
-}
-
 }  // namespace fastdeploy
@@ -25,10 +25,6 @@
 namespace fastdeploy {

 struct FASTDEPLOY_DECL FDTensor {
-  // These two parameters are general parameters of a quantized model.
-  std::pair<int32_t, std::vector<float>> quantized_parameter_ = {0, {0}};
-  void SetQuantizationInfo(int32_t zero_point, std::vector<float>& scale);
-  const std::pair<int32_t, std::vector<float>> GetQuantizationInfo() const;

   // std::vector<int8_t> data;
   void* buffer_ = nullptr;
@@ -11,7 +11,6 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
 #include "fastdeploy/vision/detection/contrib/rknpu2/postprocessor.h"
 #include "fastdeploy/vision/utils/utils.h"
@@ -38,17 +37,16 @@ bool RKYOLOPostprocessor::Run(const std::vector<FDTensor>& tensors,
     int grid_h = height_ / stride;
     int grid_w = width_ / stride;
     int* anchor = &(anchors_.data()[i * 2 * anchor_per_branch_]);
-    if (tensors[i].dtype == FDDataType::INT8 ||
-        tensors[i].dtype == FDDataType::UINT8) {
-      auto quantization_info = tensors[i].GetQuantizationInfo();
-      validCount =
-          validCount + ProcessInt8((int8_t*)tensors[i].Data() + skip_address,
-                                   anchor, grid_h, grid_w, stride,
-                                   filterBoxes, boxesScore, classId,
-                                   conf_threshold_, quantization_info.first,
-                                   quantization_info.second[0]);
+    if (tensors[i].dtype == FDDataType::FP32) {
+      validCount = validCount +
+                   ProcessFP16((float*)tensors[i].Data() + skip_address,
+                               anchor, grid_h, grid_w, stride, filterBoxes,
+                               boxesScore, classId, conf_threshold_);
     } else {
-      FDERROR << "RKYOLO Only Support INT8 Model" << std::endl;
+      FDERROR << "RKYOLO only supports FP32 models, "
+              << "but the result's type is "
+              << Str(tensors[i].dtype)
+              << std::endl;
     }
   }
@@ -69,7 +67,7 @@ bool RKYOLOPostprocessor::Run(const std::vector<FDTensor>& tensors,
     NMS(validCount, filterBoxes, classId, indexArray, nms_threshold_, false);
   } else if (anchor_per_branch_ == 1) {
     NMS(validCount, filterBoxes, classId, indexArray, nms_threshold_, true);
-  }else{
+  } else {
     FDERROR << "anchor_per_branch_ only supports 3 or 1." << std::endl;
     return false;
   }
@@ -107,60 +105,57 @@ bool RKYOLOPostprocessor::Run(const std::vector<FDTensor>& tensors,
   return true;
 }

-int RKYOLOPostprocessor::ProcessInt8(int8_t* input, int* anchor, int grid_h,
+int RKYOLOPostprocessor::ProcessFP16(float* input, int* anchor, int grid_h,
                                      int grid_w, int stride,
                                      std::vector<float>& boxes,
                                      std::vector<float>& boxScores,
-                                     std::vector<int>& classId, float threshold,
-                                     int32_t zp, float scale) {
+                                     std::vector<int>& classId,
+                                     float threshold) {
   int validCount = 0;
   int grid_len = grid_h * grid_w;
-  float thres = threshold;
-  auto thres_i8 = QntF32ToAffine(thres, zp, scale);
+  // float thres_sigmoid = threshold;
   for (int a = 0; a < anchor_per_branch_; a++) {
     for (int i = 0; i < grid_h; i++) {
       for (int j = 0; j < grid_w; j++) {
-        int8_t box_confidence =
-            input[(prob_box_size * a + 4) * grid_len + i * grid_w + j];
-        if (box_confidence >= thres_i8) {
-          int offset = (prob_box_size * a) * grid_len + i * grid_w + j;
-          int8_t* in_ptr = input + offset;
+        float box_confidence =
+            input[(prob_box_size_ * a + 4) * grid_len + i * grid_w + j];
+        if (box_confidence >= threshold) {
+          int offset = (prob_box_size_ * a) * grid_len + i * grid_w + j;
+          float* in_ptr = input + offset;

-          int8_t maxClassProbs = in_ptr[5 * grid_len];
+          float maxClassProbs = in_ptr[5 * grid_len];
           int maxClassId = 0;
-          for (int k = 1; k < obj_class_num; ++k) {
-            int8_t prob = in_ptr[(5 + k) * grid_len];
+          for (int k = 1; k < obj_class_num_; ++k) {
+            float prob = in_ptr[(5 + k) * grid_len];
             if (prob > maxClassProbs) {
               maxClassId = k;
               maxClassProbs = prob;
             }
           }

-          float box_conf_f32 = DeqntAffineToF32(box_confidence, zp, scale);
-          float class_prob_f32 = DeqntAffineToF32(maxClassProbs, zp, scale);
+          float box_conf_f32 = (box_confidence);
+          float class_prob_f32 = (maxClassProbs);
           float limit_score = 0;
           if (anchor_per_branch_ == 1) {
-            limit_score = box_conf_f32 * class_prob_f32;
-          } else {
             limit_score = class_prob_f32;
+          } else {
+            limit_score = box_conf_f32 * class_prob_f32;
           }
-          // printf("limit score: %f\n", limit_score);
+          // printf("limit score: %f", limit_score);
           if (limit_score > conf_threshold_) {
             float box_x, box_y, box_w, box_h;
             if (anchor_per_branch_ == 1) {
-              box_x = DeqntAffineToF32(*in_ptr, zp, scale);
-              box_y = DeqntAffineToF32(in_ptr[grid_len], zp, scale);
-              box_w = DeqntAffineToF32(in_ptr[2 * grid_len], zp, scale);
-              box_h = DeqntAffineToF32(in_ptr[3 * grid_len], zp, scale);
-              box_w = exp(box_w) * stride;
-              box_h = exp(box_h) * stride;
+              box_x = *in_ptr;
+              box_y = (in_ptr[grid_len]);
+              box_w = exp(in_ptr[2 * grid_len]) * stride;
+              box_h = exp(in_ptr[3 * grid_len]) * stride;
             } else {
-              box_x = DeqntAffineToF32(*in_ptr, zp, scale) * 2.0 - 0.5;
-              box_y = DeqntAffineToF32(in_ptr[grid_len], zp, scale) * 2.0 - 0.5;
-              box_w = DeqntAffineToF32(in_ptr[2 * grid_len], zp, scale) * 2.0;
-              box_h = DeqntAffineToF32(in_ptr[3 * grid_len], zp, scale) * 2.0;
-              box_w = box_w * box_w;
-              box_h = box_h * box_h;
+              box_x = *in_ptr * 2.0 - 0.5;
+              box_y = (in_ptr[grid_len]) * 2.0 - 0.5;
+              box_w = (in_ptr[2 * grid_len]) * 2.0;
+              box_h = (in_ptr[3 * grid_len]) * 2.0;
+              box_w *= box_w;
+              box_h *= box_h;
             }
             box_x = (box_x + j) * (float)stride;
             box_y = (box_y + i) * (float)stride;
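For reference, ProcessFP16 above decodes boxes straight from the FP32 tensor. A standalone sketch of the two decode rules it applies (illustrative only; for grid cell (i, j), the anchor-based w/h are presumably scaled by the anchor dimensions in code not shown in this hunk):

#include <cmath>

// Sketch of the decode rules in ProcessFP16 (not part of the diff).
// Anchor-free head (anchor_per_branch_ == 1): raw offsets, exponential w/h.
void DecodeAnchorFree(float tx, float ty, float tw, float th, int i, int j,
                      int stride, float* x, float* y, float* w, float* h) {
  *x = (tx + j) * stride;
  *y = (ty + i) * stride;
  *w = std::exp(tw) * stride;
  *h = std::exp(th) * stride;
}

// Anchor-based head: YOLOv5-style "2 * t - 0.5" offsets and squared w/h.
void DecodeAnchorBased(float tx, float ty, float tw, float th, int i, int j,
                       int stride, float* x, float* y, float* w, float* h) {
  *x = (tx * 2.0f - 0.5f + j) * stride;
  *y = (ty * 2.0f - 0.5f + i) * stride;
  *w = (tw * 2.0f) * (tw * 2.0f);
  *h = (th * 2.0f) * (th * 2.0f);
}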
@@ -85,12 +85,12 @@ class FASTDEPLOY_DECL RKYOLOPostprocessor {
   int width_ = 0;
   int anchor_per_branch_ = 0;

-  // Process Int8 Model
-  int ProcessInt8(int8_t* input, int* anchor, int grid_h, int grid_w,
-                  int stride, std::vector<float>& boxes,
-                  std::vector<float>& boxScores, std::vector<int>& classId,
-                  float threshold, int32_t zp, float scale);
+  int ProcessFP16(float* input, int* anchor, int grid_h,
+                  int grid_w, int stride,
+                  std::vector<float>& boxes,
+                  std::vector<float>& boxScores,
+                  std::vector<int>& classId,
+                  float threshold);
   // Model
   int QuickSortIndiceInverse(std::vector<float>& input, int left, int right,
                              std::vector<int>& indices);
@@ -100,8 +100,8 @@ class FASTDEPLOY_DECL RKYOLOPostprocessor {
   std::vector<float> scale_;
   float nms_threshold_ = 0.45;
   float conf_threshold_ = 0.25;
-  int prob_box_size = 85;
-  int obj_class_num = 80;
+  int prob_box_size_ = 85;
+  int obj_class_num_ = 80;
   int obj_num_bbox_max_size = 200;
 };
@@ -30,16 +30,11 @@ RKYOLOPreprocessor::RKYOLOPreprocessor() {
 }

 void RKYOLOPreprocessor::LetterBox(FDMat* mat) {
-  std::cout << "mat->Height() = " << mat->Height() << std::endl;
-  std::cout << "mat->Width() = " << mat->Width() << std::endl;
-
   float scale =
       std::min(size_[1] * 1.0 / mat->Height(), size_[0] * 1.0 / mat->Width());
-  std::cout << "RKYOLOPreprocessor scale_ = " << scale << std::endl;
   if (!is_scale_up_) {
     scale = std::min(scale, 1.0f);
   }
-  std::cout << "RKYOLOPreprocessor scale_ = " << scale << std::endl;
   scale_.push_back(scale);

   int resize_h = int(round(mat->Height() * scale));
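The cleaned-up LetterBox above keeps only the geometry. A standalone sketch of that letterbox computation (illustrative; the real code also records the pad values and scale for the postprocessor):

#include <algorithm>
#include <cmath>

// Sketch of the letterbox geometry in LetterBox() (not part of the diff):
// fit the image inside (target_w, target_h) and pad the remainder.
struct LetterBoxGeom {
  int resize_w, resize_h;  // scaled image size
  int pad_w, pad_h;        // total padding to distribute around the image
  float scale;
};

LetterBoxGeom ComputeLetterBox(int img_w, int img_h, int target_w,
                               int target_h, bool is_scale_up) {
  LetterBoxGeom g;
  g.scale = std::min(target_h * 1.0f / img_h, target_w * 1.0f / img_w);
  if (!is_scale_up) g.scale = std::min(g.scale, 1.0f);  // never enlarge
  g.resize_w = static_cast<int>(std::round(img_w * g.scale));
  g.resize_h = static_cast<int>(std::round(img_h * g.scale));
  g.pad_w = target_w - g.resize_w;
  g.pad_h = target_h - g.resize_h;
  return g;
}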
@@ -74,19 +69,6 @@ void RKYOLOPreprocessor::LetterBox(FDMat* mat) {
 }

 bool RKYOLOPreprocessor::Preprocess(FDMat* mat, FDTensor* output) {
-  // process after image load
-  // float ratio = std::min(size_[1] * 1.0f / static_cast<float>(mat->Height()),
-  //                        size_[0] * 1.0f / static_cast<float>(mat->Width()));
-  // if (std::fabs(ratio - 1.0f) > 1e-06) {
-  //   int interp = cv::INTER_AREA;
-  //   if (ratio > 1.0) {
-  //     interp = cv::INTER_LINEAR;
-  //   }
-  //   int resize_h = int(mat->Height() * ratio);
-  //   int resize_w = int(mat->Width() * ratio);
-  //   Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
-  // }
-
   // RKYOLO's preprocess steps
   // 1. letterbox
   // 2. convert_and_permute(swap_rb=true)
@@ -18,7 +18,6 @@

 namespace fastdeploy {
 namespace vision {
-
 namespace detection {
 /*! @brief Preprocessor object for the YOLOv5 series of models.
  */
@@ -60,11 +60,6 @@ bool RKYOLO::BatchPredict(const std::vector<cv::Mat>& images,
     FDERROR << "Failed to preprocess the input image." << std::endl;
     return false;
   }
-  auto pad_hw_values_ = preprocessor_.GetPadHWValues();
-  postprocessor_.SetPadHWValues(preprocessor_.GetPadHWValues());
-  std::cout << "preprocessor_ scale_ = " << preprocessor_.GetScale()[0]
-            << std::endl;
-  postprocessor_.SetScale(preprocessor_.GetScale());

   reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
   if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
@@ -72,12 +67,14 @@ bool RKYOLO::BatchPredict(const std::vector<cv::Mat>& images,
     return false;
   }

+  auto pad_hw_values_ = preprocessor_.GetPadHWValues();
+  postprocessor_.SetPadHWValues(preprocessor_.GetPadHWValues());
+  postprocessor_.SetScale(preprocessor_.GetScale());
   if (!postprocessor_.Run(reused_output_tensors_, results)) {
     FDERROR << "Failed to postprocess the inference results by runtime."
             << std::endl;
     return false;
   }

   return true;
 }