mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 17:41:52 +08:00
[CVCUDA] PP-OCR Cls & Rec preprocessor support CV-CUDA (#1470)
* ppocr cls preprocessor use manager * hwc2chw cvcuda * ppocr rec preproc use manager * ocr rec preproc cvcuda * fix rec preproc bug * ppocr cls&rec preproc set normalize * fix pybind * address comment
This commit is contained in:
@@ -22,8 +22,20 @@ namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
void OcrClassifierResizeImage(FDMat* mat,
|
||||
const std::vector<int>& cls_image_shape) {
|
||||
ClassifierPreprocessor::ClassifierPreprocessor() {
|
||||
resize_op_ = std::make_shared<Resize>(-1, -1);
|
||||
|
||||
std::vector<float> value = {0, 0, 0};
|
||||
pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);
|
||||
|
||||
normalize_op_ =
|
||||
std::make_shared<Normalize>(std::vector<float>({0.5f, 0.5f, 0.5f}),
|
||||
std::vector<float>({0.5f, 0.5f, 0.5f}), true);
|
||||
hwc2chw_op_ = std::make_shared<HWC2CHW>();
|
||||
}
|
||||
|
||||
void ClassifierPreprocessor::OcrClassifierResizeImage(
|
||||
FDMat* mat, const std::vector<int>& cls_image_shape) {
|
||||
int img_c = cls_image_shape[0];
|
||||
int img_h = cls_image_shape[1];
|
||||
int img_w = cls_image_shape[2];
|
||||
@@ -36,12 +48,8 @@ void OcrClassifierResizeImage(FDMat* mat,
|
||||
else
|
||||
resize_w = int(ceilf(img_h * ratio));
|
||||
|
||||
Resize::Run(mat, resize_w, img_h);
|
||||
}
|
||||
|
||||
bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
return Run(images, outputs, 0, images->size());
|
||||
resize_op_->SetWidthAndHeight(resize_w, img_h);
|
||||
(*resize_op_)(mat);
|
||||
}
|
||||
|
||||
bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
|
||||
@@ -55,36 +63,37 @@ bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<FDMat> mats(end_index - start_index);
|
||||
for (size_t i = start_index; i < end_index; ++i) {
|
||||
FDMat* mat = &(images->at(i));
|
||||
mats[i - start_index] = images->at(i);
|
||||
}
|
||||
return Run(&mats, outputs);
|
||||
}
|
||||
|
||||
bool ClassifierPreprocessor::Apply(FDMatBatch* image_batch,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
|
||||
FDMat* mat = &(image_batch->mats->at(i));
|
||||
OcrClassifierResizeImage(mat, cls_image_shape_);
|
||||
if (!disable_normalize_) {
|
||||
Normalize::Run(mat, mean_, scale_, is_scale_);
|
||||
(*normalize_op_)(mat);
|
||||
}
|
||||
std::vector<float> value = {0, 0, 0};
|
||||
if (mat->Width() < cls_image_shape_[2]) {
|
||||
Pad::Run(mat, 0, 0, 0, cls_image_shape_[2] - mat->Width(), value);
|
||||
pad_op_->SetPaddingSize(0, 0, 0, cls_image_shape_[2] - mat->Width());
|
||||
(*pad_op_)(mat);
|
||||
}
|
||||
|
||||
if (!disable_permute_) {
|
||||
HWC2CHW::Run(mat);
|
||||
Cast::Run(mat, "float");
|
||||
(*hwc2chw_op_)(mat);
|
||||
}
|
||||
}
|
||||
// Only have 1 output Tensor.
|
||||
// Only have 1 output tensor.
|
||||
outputs->resize(1);
|
||||
// Concat all the preprocessed data to a batch tensor
|
||||
size_t tensor_size = end_index - start_index;
|
||||
std::vector<FDTensor> tensors(tensor_size);
|
||||
for (size_t i = 0; i < tensor_size; ++i) {
|
||||
(*images)[i + start_index].ShareWithTensor(&(tensors[i]));
|
||||
tensors[i].ExpandDim(0);
|
||||
}
|
||||
if (tensors.size() == 1) {
|
||||
(*outputs)[0] = std::move(tensors[0]);
|
||||
} else {
|
||||
function::Concat(tensors, &((*outputs)[0]), 0);
|
||||
}
|
||||
// Get the NCHW tensor
|
||||
FDTensor* tensor = image_batch->Tensor();
|
||||
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
|
||||
tensor->Data(), tensor->device,
|
||||
tensor->device_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user