[Other]Refactor PaddleSeg with preprocessor && postprocessor && support batch (#639)

* Refactor PaddleSeg with preprocessor && postprocessor

* Fix bugs

* Delete redundancy code

* Modify by comments

* Refactor according to comments

* Add batch evaluation

* Add single test script

* Add ppliteseg single test script && fix eval(raise) error

* fix bug

* Fix evaluation segmentation.py batch predict

* Fix segmentation evaluation bug

* Fix evaluation segmentation bugs

Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
huangjianhui
2022-11-28 15:50:12 +08:00
committed by GitHub
parent d0307192f9
commit 312e1b097d
26 changed files with 1173 additions and 449 deletions

View File

@@ -313,11 +313,11 @@ void ArgMinMax(const FDTensor& x, FDTensor* out, int64_t axis,
FDASSERT(axis < x_rank, FDASSERT(axis < x_rank,
"'axis'(%lld) must be less than or equal to Rank(X)(%lld).", axis, "'axis'(%lld) must be less than or equal to Rank(X)(%lld).", axis,
x_rank); x_rank);
FDASSERT(output_dtype == FDDataType::INT32 || FDDataType::INT64, FDASSERT(output_dtype == FDDataType::INT32 || FDDataType::INT64 || FDDataType::UINT8,
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but " "The attribute of dtype in argmin/argmax must be [%s], [%s] or [%s], but "
"received [%s].", "received [%s].",
Str(FDDataType::INT32).c_str(), Str(FDDataType::INT64).c_str(), Str(FDDataType::INT32).c_str(), Str(FDDataType::INT64).c_str(),
Str(output_dtype).c_str()); Str(FDDataType::UINT8).c_str(), Str(output_dtype).c_str());
if (axis < 0) axis += x_rank; if (axis < 0) axis += x_rank;
if (output_dtype == FDDataType::INT32) { if (output_dtype == FDDataType::INT32) {
int64_t all_element_num = 0; int64_t all_element_num = 0;

View File

@@ -177,7 +177,7 @@ void BindRuntime(pybind11::module& m) {
} }
std::vector<FDTensor> outputs; std::vector<FDTensor> outputs;
if (!self.Infer(inputs, &outputs)) { if (!self.Infer(inputs, &outputs)) {
pybind11::eval("raise Exception('Failed to inference with Runtime.')"); throw std::runtime_error("Failed to inference with Runtime.");
} }
return outputs; return outputs;
}) })

View File

@@ -141,24 +141,26 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
} \ } \
}() }()
#define FD_VISIT_INT_FLOAT_TYPES(TYPE, NAME, ...) \ #define FD_VISIT_INT_FLOAT_TYPES(TYPE, NAME, ...) \
[&] { \ [&] { \
const auto& __dtype__ = TYPE; \ const auto& __dtype__ = TYPE; \
switch (__dtype__) { \ switch (__dtype__) { \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \
__VA_ARGS__) \ __VA_ARGS__) \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \
__VA_ARGS__) \ __VA_ARGS__) \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \
__VA_ARGS__) \ __VA_ARGS__) \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \
__VA_ARGS__) \ __VA_ARGS__) \
default: \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::UINT8, uint8_t, \
FDASSERT(false, \ __VA_ARGS__) \
"Invalid enum data type. Expect to accept data type INT32, " \ default: \
"INT64, FP32, FP64, but receive type %s.", \ FDASSERT(false, \
Str(__dtype__).c_str()); \ "Invalid enum data type. Expect to accept data type INT32, " \
} \ "INT64, FP32, FP64, UINT8 but receive type %s.", \
Str(__dtype__).c_str()); \
} \
}() }()
#define FD_VISIT_FLOAT_TYPES(TYPE, NAME, ...) \ #define FD_VISIT_FLOAT_TYPES(TYPE, NAME, ...) \
@@ -177,20 +179,22 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
} \ } \
}() }()
#define FD_VISIT_INT_TYPES(TYPE, NAME, ...) \ #define FD_VISIT_INT_TYPES(TYPE, NAME, ...) \
[&] { \ [&] { \
const auto& __dtype__ = TYPE; \ const auto& __dtype__ = TYPE; \
switch (__dtype__) { \ switch (__dtype__) { \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \
__VA_ARGS__) \ __VA_ARGS__) \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \
__VA_ARGS__) \ __VA_ARGS__) \
default: \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::UINT8, uint8_t, \
FDASSERT(false, \ __VA_ARGS__) \
"Invalid enum data type. Expect to accept data type INT32, " \ default: \
"INT64, but receive type %s.", \ FDASSERT(false, \
Str(__dtype__).c_str()); \ "Invalid enum data type. Expect to accept data type INT32, " \
} \ "INT64, UINT8 but receive type %s.", \
Str(__dtype__).c_str()); \
} \
}() }()
FASTDEPLOY_DECL std::vector<int64_t> FASTDEPLOY_DECL std::vector<int64_t>

View File

@@ -25,7 +25,7 @@ void BindPaddleClas(pybind11::module& m) {
} }
std::vector<FDTensor> outputs; std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) { if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleClasPreprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in PaddleClasPreprocessor.");
} }
if (!self.WithGpu()) { if (!self.WithGpu()) {
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
@@ -44,7 +44,7 @@ void BindPaddleClas(pybind11::module& m) {
.def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector<FDTensor>& inputs) { .def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector<FDTensor>& inputs) {
std::vector<vision::ClassifyResult> results; std::vector<vision::ClassifyResult> results;
if (!self.Run(inputs, &results)) { if (!self.Run(inputs, &results)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleClasPostprocessor.')"); throw std::runtime_error("Failed to postprocess the runtime result in PaddleClasPostprocessor.");
} }
return results; return results;
}) })
@@ -53,7 +53,7 @@ void BindPaddleClas(pybind11::module& m) {
std::vector<FDTensor> inputs; std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results)) { if (!self.Run(inputs, &results)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleClasPostprocessor.')"); throw std::runtime_error("Failed to postprocess the runtime result in PaddleClasPostprocessor.");
} }
return results; return results;
}) })

View File

@@ -22,8 +22,8 @@ namespace vision {
enum Layout { HWC, CHW }; enum Layout { HWC, CHW };
struct FASTDEPLOY_DECL Mat { struct FASTDEPLOY_DECL Mat {
Mat() = default;
explicit Mat(const cv::Mat& mat) { explicit Mat(const cv::Mat& mat) {
cpu_mat = mat; cpu_mat = mat;
layout = Layout::HWC; layout = Layout::HWC;
@@ -45,8 +45,12 @@ struct FASTDEPLOY_DECL Mat {
#endif #endif
Mat(const Mat& mat) = default; Mat(const Mat& mat) = default;
// Move assignment
Mat& operator=(const Mat& mat) = default; Mat& operator=(const Mat& mat) = default;
// Move constructor
Mat(Mat&& other) = default;
// Careful if you use this interface // Careful if you use this interface
// this only used if you don't want to write // this only used if you don't want to write
// the original data, and write to a new cv::Mat // the original data, and write to a new cv::Mat

View File

@@ -230,7 +230,7 @@ cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor) {
int channels = static_cast<int>(tensor.shape[2]); int channels = static_cast<int>(tensor.shape[2]);
return CreateZeroCopyOpenCVMatFromBuffer( return CreateZeroCopyOpenCVMatFromBuffer(
height, width, channels, type, height, width, channels, type,
const_cast<void*>(tensor.Data())); const_cast<void*>(tensor.CpuData()));
} }
#ifdef ENABLE_FLYCV #ifdef ENABLE_FLYCV

View File

@@ -285,6 +285,13 @@ std::string FaceAlignmentResult::Str() {
} }
void SegmentationResult::Clear() { void SegmentationResult::Clear() {
label_map.clear();
score_map.clear();
shape.clear();
contain_score_map = false;
}
void SegmentationResult::Free() {
std::vector<uint8_t>().swap(label_map); std::vector<uint8_t>().swap(label_map);
std::vector<float>().swap(score_map); std::vector<float>().swap(score_map);
std::vector<int64_t>().swap(shape); std::vector<int64_t>().swap(shape);
@@ -293,7 +300,7 @@ void SegmentationResult::Clear() {
void SegmentationResult::Reserve(int size) { void SegmentationResult::Reserve(int size) {
label_map.reserve(size); label_map.reserve(size);
if (contain_score_map > 0) { if (contain_score_map) {
score_map.reserve(size); score_map.reserve(size);
} }
} }
@@ -332,6 +339,18 @@ std::string SegmentationResult::Str() {
return out; return out;
} }
SegmentationResult& SegmentationResult::operator=(SegmentationResult&& other) {
if (&other != this) {
label_map = std::move(other.label_map);
shape = std::move(other.shape);
contain_score_map = std::move(other.contain_score_map);
if (contain_score_map) {
score_map.clear();
score_map = std::move(other.score_map);
}
}
return *this;
}
FaceRecognitionResult::FaceRecognitionResult(const FaceRecognitionResult& res) { FaceRecognitionResult::FaceRecognitionResult(const FaceRecognitionResult& res) {
embedding.assign(res.embedding.begin(), res.embedding.end()); embedding.assign(res.embedding.begin(), res.embedding.end());
} }

View File

@@ -247,6 +247,7 @@ struct FASTDEPLOY_DECL FaceAlignmentResult : public BaseResult {
/*! @brief Segmentation result structure for all the segmentation models /*! @brief Segmentation result structure for all the segmentation models
*/ */
struct FASTDEPLOY_DECL SegmentationResult : public BaseResult { struct FASTDEPLOY_DECL SegmentationResult : public BaseResult {
SegmentationResult() = default;
/** \brief /** \brief
* `label_map` stores the pixel-level category labels for input image. the number of pixels is equal to label_map.size() * `label_map` stores the pixel-level category labels for input image. the number of pixels is equal to label_map.size()
*/ */
@@ -257,12 +258,21 @@ struct FASTDEPLOY_DECL SegmentationResult : public BaseResult {
std::vector<float> score_map; std::vector<float> score_map;
/// The output shape, means [H, W] /// The output shape, means [H, W]
std::vector<int64_t> shape; std::vector<int64_t> shape;
/// SegmentationResult whether containing score_map
bool contain_score_map = false; bool contain_score_map = false;
/// Copy constructor
SegmentationResult(const SegmentationResult& other) = default;
/// Move assignment
SegmentationResult& operator=(SegmentationResult&& other);
ResultType type = ResultType::SEGMENTATION; ResultType type = ResultType::SEGMENTATION;
/// Clear detection result /// Clear Segmentation result
void Clear(); void Clear();
/// Clear Segmentation result and free the memory
void Free();
void Reserve(int size); void Reserve(int size);
void Resize(int size); void Resize(int size);

View File

@@ -27,7 +27,7 @@ void BindYOLOv5(pybind11::module& m) {
std::vector<FDTensor> outputs; std::vector<FDTensor> outputs;
std::vector<std::map<std::string, std::array<float, 2>>> ims_info; std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
if (!self.Run(&images, &outputs, &ims_info)) { if (!self.Run(&images, &outputs, &ims_info)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleClasPreprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in PaddleClasPreprocessor.");
} }
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing(); outputs[i].StopSharing();
@@ -45,7 +45,7 @@ void BindYOLOv5(pybind11::module& m) {
const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) { const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
std::vector<vision::DetectionResult> results; std::vector<vision::DetectionResult> results;
if (!self.Run(inputs, &results, ims_info)) { if (!self.Run(inputs, &results, ims_info)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in YOLOv5Postprocessor.')"); throw std::runtime_error("Failed to postprocess the runtime result in YOLOv5Postprocessor.");
} }
return results; return results;
}) })
@@ -55,7 +55,7 @@ void BindYOLOv5(pybind11::module& m) {
std::vector<FDTensor> inputs; std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results, ims_info)) { if (!self.Run(inputs, &results, ims_info)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in YOLOv5Postprocessor.')"); throw std::runtime_error("Failed to postprocess the runtime result in YOLOv5Postprocessor.");
} }
return results; return results;
}) })

View File

@@ -27,7 +27,7 @@ void BindYOLOv7(pybind11::module& m) {
std::vector<FDTensor> outputs; std::vector<FDTensor> outputs;
std::vector<std::map<std::string, std::array<float, 2>>> ims_info; std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
if (!self.Run(&images, &outputs, &ims_info)) { if (!self.Run(&images, &outputs, &ims_info)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleClasPreprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in YOLOV7Preprocessor.");
} }
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing(); outputs[i].StopSharing();
@@ -45,7 +45,7 @@ void BindYOLOv7(pybind11::module& m) {
const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) { const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
std::vector<vision::DetectionResult> results; std::vector<vision::DetectionResult> results;
if (!self.Run(inputs, &results, ims_info)) { if (!self.Run(inputs, &results, ims_info)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in YOLOv7Postprocessor.')"); throw std::runtime_error("Failed to postprocess the runtime result in YOLOv7Postprocessor.");
} }
return results; return results;
}) })
@@ -55,7 +55,7 @@ void BindYOLOv7(pybind11::module& m) {
std::vector<FDTensor> inputs; std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results, ims_info)) { if (!self.Run(inputs, &results, ims_info)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in YOLOv7Postprocessor.')"); throw std::runtime_error("Failed to postprocess the runtime result in YOLOv7Postprocessor.");
} }
return results; return results;
}) })

View File

@@ -25,7 +25,7 @@ void BindPPDet(pybind11::module& m) {
} }
std::vector<FDTensor> outputs; std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) { if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleDetPreprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in PaddleDetPreprocessor.");
} }
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing(); outputs[i].StopSharing();
@@ -39,7 +39,7 @@ void BindPPDet(pybind11::module& m) {
.def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<FDTensor>& inputs) { .def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<FDTensor>& inputs) {
std::vector<vision::DetectionResult> results; std::vector<vision::DetectionResult> results;
if (!self.Run(inputs, &results)) { if (!self.Run(inputs, &results)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleDetPostprocessor.')"); throw std::runtime_error("Failed to postprocess the runtime result in PaddleDetPostprocessor.");
} }
return results; return results;
}) })
@@ -52,7 +52,7 @@ void BindPPDet(pybind11::module& m) {
std::vector<FDTensor> inputs; std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results)) { if (!self.Run(inputs, &results)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleDetPostprocessor.')"); throw std::runtime_error("Failed to postprocess the runtime result in PaddleDetPostprocessor.");
} }
return results; return results;
}); });

View File

@@ -141,7 +141,7 @@ bool SCRFD::Preprocess(Mat* mat, FDTensor* output,
is_scale_up, stride); is_scale_up, stride);
BGR2RGB::Run(mat); BGR2RGB::Run(mat);
if(!this->disable_normalize_and_permute_){ if (!disable_normalize_and_permute_) {
// Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0), // Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
// std::vector<float>(mat->Channels(), 1.0)); // std::vector<float>(mat->Channels(), 1.0));
// Compute `result = mat * alpha + beta` directly by channel // Compute `result = mat * alpha + beta` directly by channel
@@ -368,7 +368,7 @@ bool SCRFD::Predict(cv::Mat* im, FaceDetectionResult* result,
return true; return true;
} }
void SCRFD::DisableNormalizeAndPermute(){ void SCRFD::DisableNormalizeAndPermute(){
this->disable_normalize_and_permute_ = true; disable_normalize_and_permute_ = true;
} }
} // namespace facedet } // namespace facedet
} // namespace vision } // namespace vision

View File

@@ -61,7 +61,7 @@ void BindPPOCRModel(pybind11::module& m) {
std::vector<std::vector<std::array<int, 8>>> results; std::vector<std::vector<std::array<int, 8>>> results;
if (!self.Run(inputs, &results, batch_det_img_info)) { if (!self.Run(inputs, &results, batch_det_img_info)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor.");
} }
return results; return results;
}) })
@@ -72,7 +72,7 @@ void BindPPOCRModel(pybind11::module& m) {
std::vector<FDTensor> inputs; std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results, batch_det_img_info)) { if (!self.Run(inputs, &results, batch_det_img_info)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor.");
} }
return results; return results;
}); });
@@ -98,7 +98,7 @@ void BindPPOCRModel(pybind11::module& m) {
} }
std::vector<FDTensor> outputs; std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) { if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPreprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in ClassifierPreprocessor.");
} }
for(size_t i = 0; i< outputs.size(); ++i){ for(size_t i = 0; i< outputs.size(); ++i){
outputs[i].StopSharing(); outputs[i].StopSharing();
@@ -114,7 +114,7 @@ void BindPPOCRModel(pybind11::module& m) {
std::vector<int> cls_labels; std::vector<int> cls_labels;
std::vector<float> cls_scores; std::vector<float> cls_scores;
if (!self.Run(inputs, &cls_labels, &cls_scores)) { if (!self.Run(inputs, &cls_labels, &cls_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor.");
} }
return make_pair(cls_labels,cls_scores); return make_pair(cls_labels,cls_scores);
}) })
@@ -125,7 +125,7 @@ void BindPPOCRModel(pybind11::module& m) {
std::vector<int> cls_labels; std::vector<int> cls_labels;
std::vector<float> cls_scores; std::vector<float> cls_scores;
if (!self.Run(inputs, &cls_labels, &cls_scores)) { if (!self.Run(inputs, &cls_labels, &cls_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor.");
} }
return make_pair(cls_labels,cls_scores); return make_pair(cls_labels,cls_scores);
}); });
@@ -152,7 +152,7 @@ void BindPPOCRModel(pybind11::module& m) {
} }
std::vector<FDTensor> outputs; std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) { if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPreprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor.");
} }
for(size_t i = 0; i< outputs.size(); ++i){ for(size_t i = 0; i< outputs.size(); ++i){
outputs[i].StopSharing(); outputs[i].StopSharing();
@@ -167,7 +167,7 @@ void BindPPOCRModel(pybind11::module& m) {
std::vector<std::string> texts; std::vector<std::string> texts;
std::vector<float> rec_scores; std::vector<float> rec_scores;
if (!self.Run(inputs, &texts, &rec_scores)) { if (!self.Run(inputs, &texts, &rec_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor.");
} }
return make_pair(texts, rec_scores); return make_pair(texts, rec_scores);
}) })
@@ -178,7 +178,7 @@ void BindPPOCRModel(pybind11::module& m) {
std::vector<std::string> texts; std::vector<std::string> texts;
std::vector<float> rec_scores; std::vector<float> rec_scores;
if (!self.Run(inputs, &texts, &rec_scores)) { if (!self.Run(inputs, &texts, &rec_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')"); throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor.");
} }
return make_pair(texts, rec_scores); return make_pair(texts, rec_scores);
}); });

View File

@@ -14,21 +14,17 @@
#include "fastdeploy/vision/segmentation/ppseg/model.h" #include "fastdeploy/vision/segmentation/ppseg/model.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
namespace segmentation { namespace segmentation {
PaddleSegModel::PaddleSegModel(const std::string& model_file, PaddleSegModel::PaddleSegModel(const std::string& model_file,
const std::string& params_file, const std::string& params_file,
const std::string& config_file, const std::string& config_file,
const RuntimeOption& custom_option, const RuntimeOption& custom_option,
const ModelFormat& model_format) { const ModelFormat& model_format) : preprocessor_(config_file),
config_file_ = config_file; postprocessor_(config_file) {
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT, valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT, Backend::LITE};
Backend::LITE};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
valid_rknpu_backends = {Backend::RKNPU2}; valid_rknpu_backends = {Backend::RKNPU2};
runtime_option = custom_option; runtime_option = custom_option;
@@ -39,13 +35,6 @@ PaddleSegModel::PaddleSegModel(const std::string& model_file,
} }
bool PaddleSegModel::Initialize() { bool PaddleSegModel::Initialize() {
reused_input_tensors_.resize(1);
reused_output_tensors_.resize(1);
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) { if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl; FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false; return false;
@@ -53,326 +42,42 @@ bool PaddleSegModel::Initialize() {
return true; return true;
} }
bool PaddleSegModel::BuildPreprocessPipelineFromConfig() { bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
processors_.clear(); return Predict(*im, result);
YAML::Node cfg; }
processors_.push_back(std::make_shared<BGR2RGB>());
try { bool PaddleSegModel::Predict(const cv::Mat& im, SegmentationResult* result) {
cfg = YAML::LoadFile(config_file_); std::vector<SegmentationResult> results;
} catch (YAML::BadFile& e) { if (!BatchPredict({im}, &results)) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false; return false;
} }
bool yml_contain_resize_op = false; *result = std::move(results[0]);
if (cfg["Deploy"]["transforms"]) {
auto preprocess_cfg = cfg["Deploy"]["transforms"];
for (const auto& op : preprocess_cfg) {
FDASSERT(op.IsMap(),
"Require the transform information in yaml be Map type.");
if (op["type"].as<std::string>() == "Normalize") {
if (!(this->disable_normalize_and_permute)) {
std::vector<float> mean = {0.5, 0.5, 0.5};
std::vector<float> std = {0.5, 0.5, 0.5};
if (op["mean"]) {
mean = op["mean"].as<std::vector<float>>();
}
if (op["std"]) {
std = op["std"].as<std::vector<float>>();
}
processors_.push_back(std::make_shared<Normalize>(mean, std));
}
} else if (op["type"].as<std::string>() == "Resize") {
yml_contain_resize_op = true;
const auto& target_size = op["target_size"];
int resize_width = target_size[0].as<int>();
int resize_height = target_size[1].as<int>();
processors_.push_back(
std::make_shared<Resize>(resize_width, resize_height));
} else {
std::string op_name = op["type"].as<std::string>();
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
}
if (cfg["Deploy"]["input_shape"]) {
auto input_shape = cfg["Deploy"]["input_shape"];
int input_batch = input_shape[0].as<int>();
int input_channel = input_shape[1].as<int>();
int input_height = input_shape[2].as<int>();
int input_width = input_shape[3].as<int>();
if (input_height == -1 || input_width == -1) {
FDWARNING << "Some exportd PaddleSeg models with dynamic shape may "
"not be able inference with ONNX Runtime/TensorRT, if error "
"happend, please try to change to use Paddle "
"Inference/OpenVINO backends instead, or export model with "
"fixed input shape." << std::endl;
}
if (input_height != -1 && input_width != -1 && !yml_contain_resize_op) {
processors_.push_back(
std::make_shared<Resize>(input_width, input_height));
}
}
if (cfg["Deploy"]["output_op"]) {
std::string output_op = cfg["Deploy"]["output_op"].as<std::string>();
if (output_op == "softmax") {
is_with_softmax = true;
is_with_argmax = false;
} else if (output_op == "argmax") {
is_with_softmax = false;
is_with_argmax = true;
} else if (output_op == "none") {
is_with_softmax = false;
is_with_argmax = false;
} else {
FDERROR << "Unexcepted output_op operator in deploy.yml: " << output_op
<< "." << std::endl;
}
}
if (!(this->disable_normalize_and_permute)) {
processors_.push_back(std::make_shared<HWC2CHW>());
}
// Fusion will improve performance
FuseTransforms(&processors_);
return true; return true;
} }
bool PaddleSegModel::Preprocess(Mat* mat, FDTensor* output) { bool PaddleSegModel::BatchPredict(const std::vector<cv::Mat>& imgs,
for (size_t i = 0; i < processors_.size(); ++i) { std::vector<SegmentationResult>* results) {
if (processors_[i]->Name().compare("Resize") == 0) { std::vector<FDMat> fd_images = WrapMat(imgs);
auto processor = dynamic_cast<Resize*>(processors_[i].get()); // Record the shape of input images
int resize_width = -1; std::map<std::string, std::vector<std::array<int, 2>>> imgs_info;
int resize_height = -1; if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &imgs_info)) {
std::tie(resize_width, resize_height) = processor->GetWidthAndHeight();
if (is_vertical_screen && (resize_width > resize_height)) {
if (!(processor->SetWidthAndHeight(resize_height, resize_width))) {
FDERROR << "Failed to set width and height of "
<< processors_[i]->Name() << " processor." << std::endl;
}
}
}
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
output->name = InputInfoOfRuntime(0).name;
return true;
}
bool PaddleSegModel::Postprocess(
FDTensor* infer_result, SegmentationResult* result,
const std::map<std::string, std::array<int, 2>>& im_info) {
// PaddleSeg has three types of inference output:
// 1. output with argmax and without softmax. 3-D matrix N(C)HW, Channel
// always 1, the element in matrix is classified label_id INT64 Type.
// 2. output without argmax and without softmax. 4-D matrix NCHW, N(batch)
// always
// 1(only support batch size 1), Channel is the num of classes. The
// element is the logits of classes
// FP32
// 3. output without argmax and with softmax. 4-D matrix NCHW, the result
// of 2 with softmax layer
// Fastdeploy output:
// 1. label_map
// 2. score_map(optional)
// 3. shape: 2-D HW
FDASSERT(infer_result->dtype == FDDataType::INT64 ||
infer_result->dtype == FDDataType::FP32 ||
infer_result->dtype == FDDataType::INT32,
"Require the data type of output is int64, fp32 or int32, but now "
"it's %s.",
Str(infer_result->dtype).c_str());
result->Clear();
FDASSERT(infer_result->shape[0] == 1, "Only support batch size = 1.");
int64_t infer_batch = infer_result->shape[0];
int64_t infer_channel = 0;
int64_t infer_height = 0;
int64_t infer_width = 0;
if (is_with_argmax) {
infer_channel = 1;
infer_height = infer_result->shape[1];
infer_width = infer_result->shape[2];
} else {
infer_channel = infer_result->shape[1];
infer_height = infer_result->shape[2];
infer_width = infer_result->shape[3];
}
int64_t infer_chw = infer_channel * infer_height * infer_width;
bool is_resized = false;
auto iter_ipt = im_info.find("input_shape");
FDASSERT(iter_ipt != im_info.end(), "Cannot find input_shape from im_info.");
int ipt_h = iter_ipt->second[0];
int ipt_w = iter_ipt->second[1];
if (ipt_h != infer_height || ipt_w != infer_width) {
is_resized = true;
}
if (!is_with_softmax && apply_softmax) {
function::Softmax(*infer_result, infer_result, 1);
}
if (!is_with_argmax) {
// output without argmax
result->contain_score_map = true;
std::vector<int64_t> dim{0, 2, 3, 1};
function::Transpose(*infer_result, infer_result, dim);
}
// batch always 1, so ignore
infer_result->shape = {infer_height, infer_width, infer_channel};
// for resize mat below
FDTensor new_infer_result;
Mat* mat = nullptr;
std::vector<float_t>* fp32_result_buffer = nullptr;
if (is_resized) {
if (infer_result->dtype == FDDataType::INT64 ||
infer_result->dtype == FDDataType::INT32) {
if (infer_result->dtype == FDDataType::INT64) {
int64_t* infer_result_buffer =
static_cast<int64_t*>(infer_result->Data());
// cv::resize don't support `CV_8S` or `CV_32S`
// refer to https://github.com/opencv/opencv/issues/20991
// https://github.com/opencv/opencv/issues/7862
fp32_result_buffer = new std::vector<float_t>(
infer_result_buffer, infer_result_buffer + infer_chw);
}
if (infer_result->dtype == FDDataType::INT32) {
int32_t* infer_result_buffer =
static_cast<int32_t*>(infer_result->Data());
// cv::resize don't support `CV_8S` or `CV_32S`
// refer to https://github.com/opencv/opencv/issues/20991
// https://github.com/opencv/opencv/issues/7862
fp32_result_buffer = new std::vector<float_t>(
infer_result_buffer, infer_result_buffer + infer_chw);
}
infer_result->Resize(infer_result->shape, FDDataType::FP32);
infer_result->SetExternalData(
infer_result->shape, FDDataType::FP32,
static_cast<void*>(fp32_result_buffer->data()));
}
mat = new Mat(Mat::Create(*infer_result, ProcLib::OPENCV));
Resize::Run(mat, ipt_w, ipt_h, -1.0f, -1.0f, 1, false, ProcLib::OPENCV);
mat->ShareWithTensor(&new_infer_result);
result->shape = new_infer_result.shape;
} else {
result->shape = infer_result->shape;
}
// output shape is 2-D HW layout, so out_num = H * W
int out_num =
std::accumulate(result->shape.begin(), result->shape.begin() + 2, 1,
std::multiplies<int>());
result->Resize(out_num);
if (result->contain_score_map) {
// output with label_map and score_map
int32_t* argmax_infer_result_buffer = nullptr;
float_t* score_infer_result_buffer = nullptr;
FDTensor argmax_infer_result;
FDTensor max_score_result;
std::vector<int64_t> reduce_dim{-1};
// argmax
if (is_resized) {
function::ArgMax(new_infer_result, &argmax_infer_result, -1, FDDataType::INT32);
function::Max(new_infer_result, &max_score_result, reduce_dim);
} else {
function::ArgMax(*infer_result, &argmax_infer_result, -1, FDDataType::INT32);
function::Max(*infer_result, &max_score_result, reduce_dim);
}
argmax_infer_result_buffer =
static_cast<int32_t*>(argmax_infer_result.Data());
score_infer_result_buffer = static_cast<float_t*>(max_score_result.Data());
for (int i = 0; i < out_num; i++) {
result->label_map[i] =
static_cast<uint8_t>(*(argmax_infer_result_buffer + i));
}
std::memcpy(result->score_map.data(), score_infer_result_buffer,
out_num * sizeof(float_t));
} else {
// output only with label_map
if (is_resized) {
float_t* infer_result_buffer =
static_cast<float_t*>(new_infer_result.Data());
for (int i = 0; i < out_num; i++) {
result->label_map[i] = static_cast<uint8_t>(*(infer_result_buffer + i));
}
} else {
if (infer_result->dtype == FDDataType::INT64) {
const int64_t* infer_result_buffer =
static_cast<const int64_t*>(infer_result->Data());
for (int i = 0; i < out_num; i++) {
result->label_map[i] =
static_cast<uint8_t>(*(infer_result_buffer + i));
}
}
if (infer_result->dtype == FDDataType::INT32) {
const int32_t* infer_result_buffer =
static_cast<const int32_t*>(infer_result->Data());
for (int i = 0; i < out_num; i++) {
result->label_map[i] =
static_cast<uint8_t>(*(infer_result_buffer + i));
}
}
}
}
// HWC remove C
result->shape.erase(result->shape.begin() + 2);
delete fp32_result_buffer;
delete mat;
mat = nullptr;
return true;
}
bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
Mat mat(*im);
std::map<std::string, std::array<int, 2>> im_info;
// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast<int>(mat.Height()),
static_cast<int>(mat.Width())};
if (!Preprocess(&mat, &(reused_input_tensors_[0]))) {
FDERROR << "Failed to preprocess input data while using model:" FDERROR << "Failed to preprocess input data while using model:"
<< ModelName() << "." << std::endl; << ModelName() << "." << std::endl;
return false; return false;
} }
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
if (!Infer()) { if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference while using model:" << ModelName() << "." FDERROR << "Failed to inference while using model:" << ModelName() << "."
<< std::endl; << std::endl;
return false; return false;
} }
if (!Postprocess(&reused_output_tensors_[0], result, im_info)) { if (!postprocessor_.Run(reused_output_tensors_, results, imgs_info)) {
FDERROR << "Failed to postprocess while using model:" << ModelName() << "." FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
<< std::endl; << std::endl;
return false; return false;
} }
return true; return true;
} }
void PaddleSegModel::DisableNormalizeAndPermute() {
this->disable_normalize_and_permute = true;
// the DisableNormalizeAndPermute function will be invalid if the
// configuration file is loaded during preprocessing
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
}
}
} // namespace segmentation } // namespace segmentation
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -14,8 +14,8 @@
#pragma once #pragma once
#include "fastdeploy/fastdeploy_model.h" #include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/segmentation/ppseg/preprocessor.h"
#include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/segmentation/ppseg/postprocessor.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
@@ -44,7 +44,7 @@ class FASTDEPLOY_DECL PaddleSegModel : public FastDeployModel {
/// Get model's name /// Get model's name
std::string ModelName() const { return "PaddleSeg"; } std::string ModelName() const { return "PaddleSeg"; }
/** \brief Predict the segmentation result for an input image /** \brief DEPRECATED Predict the segmentation result for an input image
* *
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] result The output segmentation result will be writen to this structure * \param[in] result The output segmentation result will be writen to this structure
@@ -52,36 +52,37 @@ class FASTDEPLOY_DECL PaddleSegModel : public FastDeployModel {
*/ */
virtual bool Predict(cv::Mat* im, SegmentationResult* result); virtual bool Predict(cv::Mat* im, SegmentationResult* result);
/** \brief Whether applying softmax operator in the postprocess, default value is false /** \brief Predict the segmentation result for an input image
*
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] result The output segmentation result will be writen to this structure
* \return true if the segmentation prediction successed, otherwise false
*/ */
bool apply_softmax = false; virtual bool Predict(const cv::Mat& im, SegmentationResult* result);
/** \brief For PP-HumanSeg model, set true if the input image is vertical image(height > width), default value is false /** \brief Predict the segmentation results for a batch of input images
*
* \param[in] imgs, The input image list, each element comes from cv::imread()
* \param[in] results The output segmentation result list
* \return true if the prediction successed, otherwise false
*/ */
bool is_vertical_screen = false; virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
std::vector<SegmentationResult>* results);
/// Get preprocessor reference of PaddleSegModel
virtual PaddleSegPreprocessor& GetPreprocessor() {
return preprocessor_;
}
// This function will disable normalize and hwc2chw in preprocessing step. /// Get postprocessor reference of PaddleSegModel
void DisableNormalizeAndPermute(); virtual PaddleSegPostprocessor& GetPostprocessor() {
private: return postprocessor_;
}
protected:
bool Initialize(); bool Initialize();
PaddleSegPreprocessor preprocessor_;
bool BuildPreprocessPipelineFromConfig(); PaddleSegPostprocessor postprocessor_;
bool Preprocess(Mat* mat, FDTensor* outputs);
bool Postprocess(FDTensor* infer_result, SegmentationResult* result,
const std::map<std::string, std::array<int, 2>>& im_info);
bool is_with_softmax = false;
bool is_with_argmax = true;
std::vector<std::shared_ptr<Processor>> processors_;
std::string config_file_;
// for recording the switch of normalize and hwc2chw
bool disable_normalize_and_permute = false;
}; };
} // namespace segmentation } // namespace segmentation

View File

@@ -0,0 +1,314 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/segmentation/ppseg/postprocessor.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace segmentation {
PaddleSegPostprocessor::PaddleSegPostprocessor(const std::string& config_file) {
FDASSERT(ReadFromConfig(config_file), "Failed to create PaddleSegPreprocessor.");
initialized_ = true;
}
bool PaddleSegPostprocessor::ReadFromConfig(const std::string& config_file) {
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file
<< ", maybe you should check this file." << std::endl;
return false;
}
if (cfg["Deploy"]["output_op"]) {
std::string output_op = cfg["Deploy"]["output_op"].as<std::string>();
if (output_op == "softmax") {
is_with_softmax_ = true;
is_with_argmax_ = false;
} else if (output_op == "argmax") {
is_with_softmax_ = false;
is_with_argmax_ = true;
} else if (output_op == "none") {
is_with_softmax_ = false;
is_with_argmax_ = false;
} else {
FDERROR << "Unexcepted output_op operator in deploy.yml: " << output_op
<< "." << std::endl;
return false;
}
}
return true;
}
bool PaddleSegPostprocessor::SliceOneResultFromBatchInferResults(const FDTensor& infer_results,
FDTensor* infer_result,
const std::vector<int64_t>& infer_result_shape,
const int64_t& start_idx) {
int64_t infer_batch = infer_results.shape[0];
if(infer_batch == 1) {
*infer_result = infer_results;
// batch is 1, so ignore
infer_result->shape = infer_result_shape;
} else {
if (infer_results.dtype == FDDataType::FP32) {
const float_t* infer_results_ptr =
reinterpret_cast<const float_t*>(infer_results.CpuData()) + start_idx;
infer_result->SetExternalData(
infer_result_shape, FDDataType::FP32,
reinterpret_cast<void*>(const_cast<float_t *>(infer_results_ptr)));
} else if (infer_results.dtype == FDDataType::INT64) {
const int64_t* infer_results_ptr =
reinterpret_cast<const int64_t*>(infer_results.CpuData()) + start_idx;
infer_result->SetExternalData(
infer_result_shape, FDDataType::INT64,
reinterpret_cast<void*>(const_cast<int64_t *>(infer_results_ptr)));
} else if (infer_results.dtype == FDDataType::INT32) {
const int32_t* infer_results_ptr =
reinterpret_cast<const int32_t*>(infer_results.CpuData()) + start_idx;
infer_result->SetExternalData(
infer_result_shape, FDDataType::INT32,
reinterpret_cast<void*>(const_cast<int32_t *>(infer_results_ptr)));
} else if (infer_results.dtype == FDDataType::UINT8) {
const uint8_t* infer_results_ptr =
reinterpret_cast<const uint8_t*>(infer_results.CpuData()) + start_idx;
infer_result->SetExternalData(
infer_result_shape, FDDataType::UINT8,
reinterpret_cast<void*>(const_cast<uint8_t *>(infer_results_ptr)));
} else {
FDASSERT(false,
"Require the data type for slicing is int64, fp32 or int32, but now "
"it's %s.",
Str(infer_results.dtype).c_str() )
return false;
}
}
return true;
}
bool PaddleSegPostprocessor::ProcessWithScoreResult(const FDTensor& infer_result,
const int64_t& out_num,
SegmentationResult* result) {
const uint8_t* argmax_infer_result_buffer = nullptr;
const float_t* score_infer_result_buffer = nullptr;
FDTensor argmax_infer_result;
FDTensor max_score_result;
std::vector<int64_t> reduce_dim{-1};
function::ArgMax(infer_result, &argmax_infer_result, -1, FDDataType::UINT8);
function::Max(infer_result, &max_score_result, reduce_dim);
score_infer_result_buffer = reinterpret_cast<const float_t*>(max_score_result.CpuData());
std::memcpy(result->score_map.data(), score_infer_result_buffer,
out_num * sizeof(float_t));
argmax_infer_result_buffer =
reinterpret_cast<const uint8_t*>(argmax_infer_result.CpuData());
std::memcpy(result->label_map.data(), argmax_infer_result_buffer,
out_num * sizeof(uint8_t));
return true;
}
bool PaddleSegPostprocessor::ProcessWithLabelResult(const FDTensor& infer_result,
const int64_t& out_num,
SegmentationResult* result) {
if (infer_result.dtype == FDDataType::INT64) {
const int64_t* infer_result_buffer =
reinterpret_cast<const int64_t*>(infer_result.CpuData());
for (int i = 0; i < out_num; i++) {
result->label_map[i] =
static_cast<uint8_t>(*(infer_result_buffer + i));
}
} else if (infer_result.dtype == FDDataType::INT32) {
const int32_t* infer_result_buffer =
reinterpret_cast<const int32_t*>(infer_result.CpuData());
for (int i = 0; i < out_num; i++) {
result->label_map[i] =
static_cast<uint8_t>(*(infer_result_buffer + i));
}
} else if (infer_result.dtype == FDDataType::UINT8) {
const uint8_t* infer_result_buffer =
reinterpret_cast<const uint8_t*>(infer_result.CpuData());
memcpy(result->label_map.data(), infer_result_buffer, out_num * sizeof(uint8_t));
}
else {
FDASSERT(false,
"Require the data type to process is int64, int32 or uint8, but now "
"it's %s.",
Str(infer_result.dtype).c_str());
return false;
}
return true;
}
bool PaddleSegPostprocessor::FDTensorCast2Uint8(FDTensor* infer_result,
const int64_t& offset,
std::vector<uint8_t>* uint8_result_buffer) {
FDDataType infer_result_dtype = infer_result->dtype;
if (infer_result_dtype == FDDataType::INT64) {
const int64_t* infer_result_buffer =
reinterpret_cast<const int64_t*>(infer_result->CpuData());
// cv::resize don't support `CV_8S` or `CV_32S`
// refer to https://github.com/opencv/opencv/issues/20991
// https://github.com/opencv/opencv/issues/7862
uint8_result_buffer = new std::vector<uint8_t>(
infer_result_buffer, infer_result_buffer + offset);
} else if (infer_result_dtype == FDDataType::INT32) {
const int32_t* infer_result_buffer =
reinterpret_cast<const int32_t*>(infer_result->CpuData());
// cv::resize don't support `CV_8S` or `CV_32S`
// refer to https://github.com/opencv/opencv/issues/20991
// https://github.com/opencv/opencv/issues/7862
uint8_result_buffer = new std::vector<uint8_t>(
infer_result_buffer, infer_result_buffer + offset);
} else {
FDASSERT(false,
"Require the data type for casting uint8 is int64, int32, but now "
"it's %s.",
Str(infer_result_dtype).c_str());
return false;
}
infer_result->SetExternalData(
infer_result->shape, FDDataType::UINT8,
reinterpret_cast<void*>(uint8_result_buffer->data()));
return true;
}
bool PaddleSegPostprocessor::Run(
const std::vector<FDTensor>& infer_results,
std::vector<SegmentationResult>* results,
const std::map<std::string, std::vector<std::array<int, 2>>>& imgs_info) {
// PaddleSeg has three types of inference output:
// 1. output with argmax and without softmax. 3-D matrix N(C)HW, Channel
// is batch_size, the element in matrix is classified label_id INT64 type.
// 2. output without argmax and without softmax. 4-D matrix NCHW, N(batch)
// is batch_size, Channel is the num of classes. The element is the logits
// of classes FP32 type
// 3. output without argmax and with softmax. 4-D matrix NCHW, the result
// of 2 with softmax layer
// Fastdeploy output:
// 1. label_map
// 2. score_map(optional)
// 3. shape: 2-D HW
if (!initialized_) {
FDERROR << "Postprocessor is not initialized." << std::endl;
return false;
}
FDDataType infer_results_dtype = infer_results[0].dtype;
FDASSERT(infer_results_dtype == FDDataType::INT64 ||
infer_results_dtype == FDDataType::FP32 ||
infer_results_dtype == FDDataType::INT32,
"Require the data type of output is int64, fp32 or int32, but now "
"it's %s.",
Str(infer_results_dtype).c_str());
auto iter_input_imgs_shape_list = imgs_info.find("shape_info");
FDASSERT(iter_input_imgs_shape_list != imgs_info.end(), "Cannot find shape_info from imgs_info.");
// For Argmax Softmax function to store transformed result below
FDTensor transform_infer_results;
int64_t infer_batch = infer_results[0].shape[0];
int64_t infer_channel = 0;
int64_t infer_height = 0;
int64_t infer_width = 0;
if (is_with_argmax_) {
// infer_results with argmax
infer_channel = 1;
infer_height = infer_results[0].shape[1];
infer_width = infer_results[0].shape[2];
} else {
// infer_results without argmax
infer_channel = 1;
infer_height = infer_results[0].shape[2];
infer_width = infer_results[0].shape[3];
if (store_score_map_) {
infer_channel = infer_results[0].shape[1];
std::vector<int64_t> dim{0, 2, 3, 1};
function::Transpose(infer_results[0], &transform_infer_results, dim);
if (!is_with_softmax_ && apply_softmax_) {
function::Softmax(transform_infer_results, &transform_infer_results, 1);
}
} else {
function::ArgMax(infer_results[0], &transform_infer_results, 1, FDDataType::UINT8);
infer_results_dtype = transform_infer_results.dtype;
}
}
int64_t infer_chw = infer_channel * infer_height * infer_width;
results->resize(infer_batch);
for (int i = 0; i < infer_batch; i++) {
SegmentationResult* result = &((*results)[i]);
result->Clear();
int64_t start_idx = i * infer_chw;
FDTensor infer_result;
std::vector<int64_t> infer_result_shape = {infer_height, infer_width, infer_channel};
if (is_with_argmax_) {
SliceOneResultFromBatchInferResults(infer_results[0],
&infer_result,
infer_result_shape,
start_idx);
} else {
SliceOneResultFromBatchInferResults(transform_infer_results,
&infer_result,
infer_result_shape,
start_idx);
}
bool is_resized = false;
int input_height = iter_input_imgs_shape_list->second[i][0];
int input_width = iter_input_imgs_shape_list->second[i][1];
if (input_height != infer_height || input_width != infer_width) {
is_resized = true;
}
FDMat mat;
std::vector<uint8_t> uint8_result_buffer;
if (is_resized) {
if (infer_results_dtype == FDDataType::INT64 ||
infer_results_dtype == FDDataType::INT32 ){
FDTensorCast2Uint8(&infer_result, infer_chw, &uint8_result_buffer);
}
mat = std::move(Mat::Create(infer_result, ProcLib::OPENCV));
Resize::Run(&mat, input_width, input_height, -1.0f, -1.0f, 1, false, ProcLib::OPENCV);
mat.ShareWithTensor(&infer_result);
}
result->shape = infer_result.shape;
// output shape is 2-D HW layout, so out_num = H * W
int out_num =
std::accumulate(result->shape.begin(), result->shape.begin() + 2, 1,
std::multiplies<int>());
if (!is_with_argmax_ && store_score_map_) {
// output with label_map and score_map
result->contain_score_map = true;
result->Resize(out_num);
ProcessWithScoreResult(infer_result, out_num, result);
} else {
result->Resize(out_num);
ProcessWithLabelResult(infer_result, out_num, result);
}
// HWC remove C
result->shape.erase(result->shape.begin() + 2);
}
return true;
}
} // namespace segmentation
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,99 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace segmentation {
class FASTDEPLOY_DECL PaddleSegPostprocessor {
public:
/** \brief Create a postprocessor instance for PaddleSeg serials model
*
* \param[in] config_file Path of configuration file for deployment, e.g ppliteseg/deploy.yaml
*/
explicit PaddleSegPostprocessor(const std::string& config_file);
/** \brief Process the result of runtime and fill to SegmentationResult structure
*
* \param[in] tensors The inference result from runtime
* \param[in] result The output result of detection
* \param[in] imgs_info The original input images shape info map, key is "shape_info", value is vector<array<int, 2>> a{{height, width}}
* \return true if the postprocess successed, otherwise false
*/
virtual bool Run(
const std::vector<FDTensor>& infer_results,
std::vector<SegmentationResult>* results,
const std::map<std::string, std::vector<std::array<int, 2>>>& imgs_info);
/** \brief Get apply_softmax property of PaddleSeg model, default is false
*/
bool GetApplySoftmax() const {
return apply_softmax_;
}
/// Set apply_softmax value, bool type required
void SetApplySoftmax(bool value) {
apply_softmax_ = value;
}
/// Get store_score_map property of PaddleSeg model, default is false
bool GetStoreScoreMap() const {
return store_score_map_;
}
/// Set store_score_map value, bool type required
void SetStoreScoreMap(bool value) {
store_score_map_ = value;
}
private:
virtual bool ReadFromConfig(const std::string& config_file);
virtual bool SliceOneResultFromBatchInferResults(
const FDTensor& infer_results,
FDTensor* infer_result,
const std::vector<int64_t>& infer_result_shape,
const int64_t& start_idx);
virtual bool ProcessWithScoreResult(const FDTensor& infer_result,
const int64_t& out_num,
SegmentationResult* result);
virtual bool ProcessWithLabelResult(const FDTensor& infer_result,
const int64_t& out_num,
SegmentationResult* result);
virtual bool FDTensorCast2Uint8(FDTensor* infer_result,
const int64_t& offset,
std::vector<uint8_t>* uint8_result_buffer);
bool is_with_softmax_ = false;
bool is_with_argmax_ = true;
bool apply_softmax_ = false;
bool store_score_map_ = false;
bool initialized_ = false;
};
} // namespace segmentation
} // namespace vision
} // namespace fastdeploy

View File

@@ -15,6 +15,34 @@
namespace fastdeploy { namespace fastdeploy {
void BindPPSeg(pybind11::module& m) { void BindPPSeg(pybind11::module& m) {
pybind11::class_<vision::segmentation::PaddleSegPreprocessor>(
m, "PaddleSegPreprocessor")
.def(pybind11::init<std::string>())
.def("run",
[](vision::segmentation::PaddleSegPreprocessor& self,
std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
// Record the shape of input images
std::map<std::string, std::vector<std::array<int, 2>>> imgs_info;
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs, &imgs_info)) {
throw std::runtime_error("Failed to preprocess the input data in PaddleSegPreprocessor.");
}
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
return make_pair(outputs, imgs_info);;
})
.def("disable_normalize_and_permute",
&vision::segmentation::PaddleSegPreprocessor::DisableNormalizeAndPermute)
.def_property("is_vertical_screen",
&vision::segmentation::PaddleSegPreprocessor::GetIsVerticalScreen,
&vision::segmentation::PaddleSegPreprocessor::SetIsVerticalScreen);
pybind11::class_<vision::segmentation::PaddleSegModel, FastDeployModel>( pybind11::class_<vision::segmentation::PaddleSegModel, FastDeployModel>(
m, "PaddleSegModel") m, "PaddleSegModel")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
@@ -23,14 +51,53 @@ void BindPPSeg(pybind11::module& m) {
[](vision::segmentation::PaddleSegModel& self, [](vision::segmentation::PaddleSegModel& self,
pybind11::array& data) { pybind11::array& data) {
auto mat = PyArrayToCvMat(data); auto mat = PyArrayToCvMat(data);
vision::SegmentationResult* res = new vision::SegmentationResult(); vision::SegmentationResult res;
self.Predict(&mat, res); self.Predict(&mat, &res);
return res; return res;
}) })
.def("disable_normalize_and_permute",&vision::segmentation::PaddleSegModel::DisableNormalizeAndPermute) .def("batch_predict",
.def_readwrite("apply_softmax", [](vision::segmentation::PaddleSegModel& self, std::vector<pybind11::array>& data) {
&vision::segmentation::PaddleSegModel::apply_softmax) std::vector<cv::Mat> images;
.def_readwrite("is_vertical_screen", for (size_t i = 0; i < data.size(); ++i) {
&vision::segmentation::PaddleSegModel::is_vertical_screen); images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::SegmentationResult> results;
self.BatchPredict(images, &results);
return results;
})
.def_property_readonly("preprocessor", &vision::segmentation::PaddleSegModel::GetPreprocessor)
.def_property_readonly("postprocessor", &vision::segmentation::PaddleSegModel::GetPostprocessor);
pybind11::class_<vision::segmentation::PaddleSegPostprocessor>(
m, "PaddleSegPostprocessor")
.def(pybind11::init<std::string>())
.def("run",
[](vision::segmentation::PaddleSegPostprocessor& self,
std::vector<FDTensor>& inputs,
const std::map<std::string, std::vector<std::array<int, 2>>>& imgs_info) {
std::vector<vision::SegmentationResult> results;
if (!self.Run(inputs, &results, imgs_info)) {
throw std::runtime_error("Failed to postprocess the runtime result in PaddleSegPostprocessor.");
}
return results;
})
.def("run",
[](vision::segmentation::PaddleSegPostprocessor& self,
std::vector<pybind11::array>& input_array,
const std::map<std::string, std::vector<std::array<int, 2>>>& imgs_info) {
std::vector<vision::SegmentationResult> results;
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results, imgs_info)) {
throw std::runtime_error("Failed to postprocess the runtime result in PaddleSegPostprocessor.");
}
return results;
})
.def_property("apply_softmax",
&vision::segmentation::PaddleSegPostprocessor::GetApplySoftmax,
&vision::segmentation::PaddleSegPostprocessor::SetApplySoftmax)
.def_property("store_score_map",
&vision::segmentation::PaddleSegPostprocessor::GetStoreScoreMap,
&vision::segmentation::PaddleSegPostprocessor::SetStoreScoreMap);
} }
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -0,0 +1,169 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/segmentation/ppseg/preprocessor.h"
#include "fastdeploy/function/concat.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace segmentation {
PaddleSegPreprocessor::PaddleSegPreprocessor(const std::string& config_file) {
this->config_file_ = config_file;
FDASSERT(BuildPreprocessPipelineFromConfig(), "Failed to create PaddleSegPreprocessor.");
initialized_ = true;
}
bool PaddleSegPreprocessor::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
processors_.push_back(std::make_shared<BGR2RGB>());
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
if (cfg["Deploy"]["transforms"]) {
auto preprocess_cfg = cfg["Deploy"]["transforms"];
for (const auto& op : preprocess_cfg) {
FDASSERT(op.IsMap(),
"Require the transform information in yaml be Map type.");
if (op["type"].as<std::string>() == "Normalize") {
if (!disable_normalize_and_permute_) {
std::vector<float> mean = {0.5, 0.5, 0.5};
std::vector<float> std = {0.5, 0.5, 0.5};
if (op["mean"]) {
mean = op["mean"].as<std::vector<float>>();
}
if (op["std"]) {
std = op["std"].as<std::vector<float>>();
}
processors_.push_back(std::make_shared<Normalize>(mean, std));
}
} else if (op["type"].as<std::string>() == "Resize") {
is_contain_resize_op = true;
const auto& target_size = op["target_size"];
int resize_width = target_size[0].as<int>();
int resize_height = target_size[1].as<int>();
processors_.push_back(
std::make_shared<Resize>(resize_width, resize_height));
} else {
std::string op_name = op["type"].as<std::string>();
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
}
if (cfg["Deploy"]["input_shape"]) {
auto input_shape = cfg["Deploy"]["input_shape"];
int input_height = input_shape[2].as<int>();
int input_width = input_shape[3].as<int>();
if (input_height != -1 && input_width != -1 && !is_contain_resize_op) {
is_contain_resize_op = true;
processors_.insert(processors_.begin(),
std::make_shared<Resize>(input_width, input_height));
}
}
if (!disable_normalize_and_permute_) {
processors_.push_back(std::make_shared<HWC2CHW>());
}
// Fusion will improve performance
FuseTransforms(&processors_);
return true;
}
bool PaddleSegPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs, std::map<std::string, std::vector<std::array<int, 2>>>* imgs_info) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}
std::vector<std::array<int, 2>> shape_info;
for (const auto& image : *images) {
shape_info.push_back({static_cast<int>(image.Height()),
static_cast<int>(image.Width())});
}
(*imgs_info)["shape_info"] = shape_info;
for (size_t i = 0; i < processors_.size(); ++i) {
if (processors_[i]->Name() == "Resize") {
auto processor = dynamic_cast<Resize*>(processors_[i].get());
int resize_width = -1;
int resize_height = -1;
std::tie(resize_width, resize_height) = processor->GetWidthAndHeight();
if (is_vertical_screen_ && (resize_width > resize_height)) {
if (!(processor->SetWidthAndHeight(resize_height, resize_width))) {
FDERROR << "Failed to set width and height of "
<< processors_[i]->Name() << " processor." << std::endl;
}
}
break;
}
}
size_t img_num = images->size();
// Batch preprocess : resize all images to the largest image shape in batch
if (!is_contain_resize_op && img_num > 1) {
int max_width = 0;
int max_height = 0;
for (size_t i = 0; i < img_num; ++i) {
max_width = std::max(max_width, ((*images)[i]).Width());
max_height = std::max(max_height, ((*images)[i]).Height());
}
for (size_t i = 0; i < img_num; ++i) {
Resize::Run(&(*images)[i], max_width, max_height);
}
}
for (size_t i = 0; i < img_num; ++i) {
for (size_t j = 0; j < processors_.size(); ++j) {
if (!(*(processors_[j].get()))(&((*images)[i]))) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
}
outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(img_num);
for (size_t i = 0; i < img_num; ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
return true;
}
void PaddleSegPreprocessor::DisableNormalizeAndPermute(){
disable_normalize_and_permute_ = true;
// the DisableNormalizeAndPermute function will be invalid if the configuration file is loaded during preprocessing
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
}
}
} // namespace segmentation
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,73 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace segmentation {
class FASTDEPLOY_DECL PaddleSegPreprocessor {
public:
/** \brief Create a preprocessor instance for PaddleSeg serials model
*
* \param[in] config_file Path of configuration file for deployment, e.g ppliteseg/deploy.yaml
*/
explicit PaddleSegPreprocessor(const std::string& config_file);
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime, include image
* \return true if the preprocess successed, otherwise false
*/
virtual bool Run(
std::vector<FDMat>* images,
std::vector<FDTensor>* outputs,
std::map<std::string, std::vector<std::array<int, 2>>>* imgs_info);
/// Get is_vertical_screen property of PP-HumanSeg model, default is false
bool GetIsVerticalScreen() const {
return is_vertical_screen_;
}
/// Set is_vertical_screen value, bool type required
void SetIsVerticalScreen(bool value) {
is_vertical_screen_ = value;
}
// This function will disable normalize and hwc2chw in preprocessing step.
void DisableNormalizeAndPermute();
private:
virtual bool BuildPreprocessPipelineFromConfig();
std::vector<std::shared_ptr<Processor>> processors_;
std::string config_file_;
/** \brief For PP-HumanSeg model, set true if the input image is vertical image(height > width), default value is false
*/
bool is_vertical_screen_ = false;
// for recording the switch of normalize and hwc2chw
bool disable_normalize_and_permute_ = false;
bool is_contain_resize_op = false;
bool initialized_ = false;
};
} // namespace segmentation
} // namespace vision
} // namespace fastdeploy

View File

@@ -103,6 +103,7 @@ void BindVision(pybind11::module& m) {
.def_readwrite("label_map", &vision::SegmentationResult::label_map) .def_readwrite("label_map", &vision::SegmentationResult::label_map)
.def_readwrite("score_map", &vision::SegmentationResult::score_map) .def_readwrite("score_map", &vision::SegmentationResult::score_map)
.def_readwrite("shape", &vision::SegmentationResult::shape) .def_readwrite("shape", &vision::SegmentationResult::shape)
.def_readwrite("contain_score_map", &vision::SegmentationResult::contain_score_map)
.def("__repr__", &vision::SegmentationResult::Str) .def("__repr__", &vision::SegmentationResult::Str)
.def("__str__", &vision::SegmentationResult::Str); .def("__str__", &vision::SegmentationResult::Str);
@@ -111,7 +112,7 @@ void BindVision(pybind11::module& m) {
.def_readwrite("alpha", &vision::MattingResult::alpha) .def_readwrite("alpha", &vision::MattingResult::alpha)
.def_readwrite("foreground", &vision::MattingResult::foreground) .def_readwrite("foreground", &vision::MattingResult::foreground)
.def_readwrite("shape", &vision::MattingResult::shape) .def_readwrite("shape", &vision::MattingResult::shape)
.def_readwrite("contain_foreground", &vision::MattingResult::shape) .def_readwrite("contain_foreground", &vision::MattingResult::contain_foreground)
.def("__repr__", &vision::MattingResult::Str) .def("__repr__", &vision::MattingResult::Str)
.def("__str__", &vision::MattingResult::Str); .def("__str__", &vision::MattingResult::Str);

View File

@@ -20,7 +20,7 @@ import math
import time import time
def eval_segmentation(model, data_dir): def eval_segmentation(model, data_dir, batch_size=1):
import cv2 import cv2
from .utils import Cityscapes from .utils import Cityscapes
from .utils import f1_score, calculate_area, mean_iou, accuracy, kappa from .utils import f1_score, calculate_area, mean_iou, accuracy, kappa
@@ -39,6 +39,8 @@ def eval_segmentation(model, data_dir):
start_time = 0 start_time = 0
end_time = 0 end_time = 0
average_inference_time = 0 average_inference_time = 0
im_list = []
label_list = []
for image_label_path, i in zip(file_list, for image_label_path, i in zip(file_list,
trange( trange(
image_num, desc="Inference Progress")): image_num, desc="Inference Progress")):
@@ -46,19 +48,31 @@ def eval_segmentation(model, data_dir):
start_time = time.time() start_time = time.time()
im = cv2.imread(image_label_path[0]) im = cv2.imread(image_label_path[0])
label = cv2.imread(image_label_path[1], cv2.IMREAD_GRAYSCALE) label = cv2.imread(image_label_path[1], cv2.IMREAD_GRAYSCALE)
result = model.predict(im) label_list.append(label)
if batch_size == 1:
result = model.predict(im)
results = [result]
else:
im_list.append(im)
# If the batch_size is not satisfied, the remaining pictures are formed into a batch
if (i + 1) % batch_size != 0 and i != image_num - 1:
continue
results = model.batch_predict(im_list)
if i == image_num - 1: if i == image_num - 1:
end_time = time.time() end_time = time.time()
average_inference_time = round( average_inference_time = round(
(end_time - start_time) / (image_num - twenty_percent_image_num), (end_time - start_time) /
4) (image_num - twenty_percent_image_num), 4)
pred = np.array(result.label_map).reshape(result.shape[0], for result, label in zip(results, label_list):
result.shape[1]) pred = np.array(result.label_map).reshape(result.shape[0],
intersect_area, pred_area, label_area = calculate_area(pred, label, result.shape[1])
num_classes) intersect_area, pred_area, label_area = calculate_area(pred, label,
intersect_area_all = intersect_area_all + intersect_area num_classes)
pred_area_all = pred_area_all + pred_area intersect_area_all = intersect_area_all + intersect_area
label_area_all = label_area_all + label_area pred_area_all = pred_area_all + pred_area
label_area_all = label_area_all + label_area
im_list.clear()
label_list.clear()
class_iou, miou = mean_iou(intersect_area_all, pred_area_all, class_iou, miou = mean_iou(intersect_area_all, pred_area_all,
label_area_all) label_area_all)

View File

@@ -13,4 +13,4 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from .ppseg import PaddleSegModel from .ppseg import *

View File

@@ -41,35 +41,55 @@ class PaddleSegModel(FastDeployModel):
model_format) model_format)
assert self.initialized, "PaddleSeg model initialize failed." assert self.initialized, "PaddleSeg model initialize failed."
def predict(self, input_image): def predict(self, image):
"""Predict the segmentation result for an input image """Predict the segmentation result for an input image
:param im: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format :param im: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: SegmentationResult :return: SegmentationResult
""" """
return self._model.predict(input_image) return self._model.predict(image)
def disable_normalize_and_permute(self): def batch_predict(self, image_list):
return self._model.disable_normalize_and_permute() """Predict the segmentation results for a batch of input image
:param image_list: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return list of SegmentationResult
"""
return self._model.batch_predict(image_list)
@property @property
def apply_softmax(self): def preprocessor(self):
"""Atrribute of PaddleSeg model. Stating Whether applying softmax operator in the postprocess, default value is False """Get PaddleSegPreprocessor object of the loaded model
:return PaddleSegPreprocessor
:return: value of apply_softmax(bool)
""" """
return self._model.apply_softmax return self._model.preprocessor
@apply_softmax.setter @property
def apply_softmax(self, value): def postprocessor(self):
"""Set attribute apply_softmax of PaddleSeg model. """Get PaddleSegPostprocessor object of the loaded model
:return PaddleSegPostprocessor
:param value: (bool)The value to set apply_softmax
""" """
assert isinstance( return self._model.postprocessor
value,
bool), "The value to set `apply_softmax` must be type of bool."
self._model.apply_softmax = value class PaddleSegPreprocessor:
def __init__(self, config_file):
"""Create a preprocessor for PaddleSegModel from configuration file
:param config_file: (str)Path of configuration file, e.g ppliteseg/deploy.yaml
"""
self._preprocessor = C.vision.segmentation.PaddleSegPreprocessor(
config_file)
def run(self, input_ims):
"""Preprocess input images for PaddleSegModel
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
def disable_normalize_and_permute(self):
"""To disable normalize and hwc2chw in preprocessing step.
"""
return self._preprocessor.disable_normalize_and_permute()
@property @property
def is_vertical_screen(self): def is_vertical_screen(self):
@@ -77,7 +97,7 @@ class PaddleSegModel(FastDeployModel):
:return: value of is_vertical_screen(bool) :return: value of is_vertical_screen(bool)
""" """
return self._model.is_vertical_screen return self._preprocessor.is_vertical_screen
@is_vertical_screen.setter @is_vertical_screen.setter
def is_vertical_screen(self, value): def is_vertical_screen(self, value):
@@ -88,4 +108,59 @@ class PaddleSegModel(FastDeployModel):
assert isinstance( assert isinstance(
value, value,
bool), "The value to set `is_vertical_screen` must be type of bool." bool), "The value to set `is_vertical_screen` must be type of bool."
self._model.is_vertical_screen = value self._preprocessor.is_vertical_screen = value
class PaddleSegPostprocessor:
def __init__(self, config_file):
"""Create a postprocessor for PaddleSegModel from configuration file
:param config_file: (str)Path of configuration file, e.g ppliteseg/deploy.yaml
"""
self._postprocessor = C.vision.segmentation.PaddleSegPostprocessor(
config_file)
def run(self, runtime_results, imgs_info):
"""Postprocess the runtime results for PaddleSegModel
:param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
:param: imgs_info: The original input images shape info map, key is "shape_info", value is [[image_height, image_width]]
:return: list of SegmentationResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results, imgs_info)
@property
def apply_softmax(self):
"""Atrribute of PaddleSeg model. Stating Whether applying softmax operator in the postprocess, default value is False
:return: value of apply_softmax(bool)
"""
return self._postprocessor.apply_softmax
@apply_softmax.setter
def apply_softmax(self, value):
"""Set attribute apply_softmax of PaddleSeg model.
:param value: (bool)The value to set apply_softmax
"""
assert isinstance(
value,
bool), "The value to set `apply_softmax` must be type of bool."
self._postprocessor.apply_softmax = value
@property
def store_score_map(self):
"""Atrribute of PaddleSeg model. Stating Whether storing score map in the SegmentationResult, default value is False
:return: value of store_score_map(bool)
"""
return self._postprocessor.store_score_map
@store_score_map.setter
def store_score_map(self, value):
"""Set attribute store_score_map of PaddleSeg model.
:param value: (bool)The value to set store_score_map
"""
assert isinstance(
value,
bool), "The value to set `store_score_map` must be type of bool."
self._postprocessor.store_score_map = value

View File

@@ -312,8 +312,10 @@ TEST(fastdeploy, reduce_argmax) {
std::vector<int> inputs = {2, 4, 3, 7, 1, 5}; std::vector<int> inputs = {2, 4, 3, 7, 1, 5};
std::vector<int64_t> expected_result_axis0 = {1, 0, 1}; std::vector<int64_t> expected_result_axis0 = {1, 0, 1};
std::vector<uint8_t> expected_result_uint8_axis0 = {1, 0, 1};
std::vector<int64_t> expected_result_axis1 = {1, 0}; std::vector<int64_t> expected_result_axis1 = {1, 0};
std::vector<int64_t> expected_result_noaxis = {3}; std::vector<int64_t> expected_result_noaxis = {3};
input.SetExternalData({2, 3}, FDDataType::INT32, inputs.data()); input.SetExternalData({2, 3}, FDDataType::INT32, inputs.data());
// axis = 0, output_dtype = FDDataType::INT64, keep_dim = false, flatten = // axis = 0, output_dtype = FDDataType::INT64, keep_dim = false, flatten =
@@ -323,6 +325,13 @@ TEST(fastdeploy, reduce_argmax) {
check_data(reinterpret_cast<const int64_t*>(output.Data()), check_data(reinterpret_cast<const int64_t*>(output.Data()),
expected_result_axis0.data(), expected_result_axis0.size()); expected_result_axis0.data(), expected_result_axis0.size());
// axis = 0, output_dtype = FDDataType::UINT8, keep_dim = false, flatten =
// false
ArgMax(input, &output, 0, FDDataType::UINT8);
check_shape(output.shape, {3});
check_data(reinterpret_cast<const uint8_t*>(output.Data()),
expected_result_uint8_axis0.data(), expected_result_axis0.size());
// axis = -1, output_dtype = FDDataType::INT64, keep_dim = false, flatten = // axis = -1, output_dtype = FDDataType::INT64, keep_dim = false, flatten =
// false // false
ArgMax(input, &output, -1); ArgMax(input, &output, -1);

View File

@@ -0,0 +1,160 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
from fastdeploy import ModelFormat
import cv2
import os
import numpy as np
import runtime_config as rc
import pickle
def test_segmentation_ppliteseg():
pp_liteseg_model_url = "https://bj.bcebos.com/fastdeploy/tests/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_test.tgz"
fd.download_and_decompress(pp_liteseg_model_url, "resources")
model_path = "./resources/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_test"
# 配置runtime加载模型
runtime_option = fd.RuntimeOption()
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "deploy.yaml")
image_file_1 = os.path.join(model_path, "cityscapes_demo_1.png")
image_file_2 = os.path.join(model_path, "cityscapes_demo_2.png")
result_file_1 = os.path.join(model_path, "ppliteseg_result1.pkl")
result_file_2 = os.path.join(model_path, "ppliteseg_result2.pkl")
model = fd.vision.segmentation.PaddleSegModel(
model_file, params_file, config_file, runtime_option=rc.test_option)
model.postprocessor.store_score_map = True
im1 = cv2.imread(image_file_1)
im2 = cv2.imread(image_file_2)
with open(result_file_1, "rb") as f:
expect1 = pickle.load(f)
with open(result_file_2, "rb") as f:
expect2 = pickle.load(f)
for i in range(3):
# test single predict
result1 = model.predict(im1)
result2 = model.predict(im2)
diff_label_map_1 = np.fabs(
np.array(result1.label_map) - np.array(expect1["label_map"]))
diff_label_map_2 = np.fabs(
np.array(result2.label_map) - np.array(expect2["label_map"]))
diff_score_map_1 = np.fabs(
np.array(result1.score_map) - np.array(expect1["score_map"]))
diff_score_map_2 = np.fabs(
np.array(result2.score_map) - np.array(expect2["score_map"]))
thres = 1e-05
assert diff_label_map_1.max(
) < thres, "The label_map diff is %f, which is bigger than %f" % (
diff_label_map_1.max(), thres)
assert diff_score_map_1.max(
) < thres, "The score map diff is %f, which is bigger than %f" % (
diff_score_map_1.max(), thres)
assert diff_label_map_2.max(
) < thres, "The label_map diff is %f, which is bigger than %f" % (
diff_label_map_2.max(), thres)
assert diff_score_map_2.max(
) < thres, "The score map diff is %f, which is bigger than %f" % (
diff_score_map_2.max(), thres)
print("Single image No diff")
# test batch predict
results = model.batch_predict([im1, im2])
result1 = results[0]
result2 = results[1]
diff_label_map_1 = np.fabs(
np.array(result1.label_map) - np.array(expect1["label_map"]))
diff_label_map_2 = np.fabs(
np.array(result2.label_map) - np.array(expect2["label_map"]))
diff_score_map_1 = np.fabs(
np.array(result1.score_map) - np.array(expect1["score_map"]))
diff_score_map_2 = np.fabs(
np.array(result2.score_map) - np.array(expect2["score_map"]))
thres = 1e-05
assert diff_label_map_1.max(
) < thres, "The label_map diff is %f, which is bigger than %f" % (
diff_label_map_1.max(), thres)
assert diff_score_map_1.max(
) < thres, "The score map diff is %f, which is bigger than %f" % (
diff_score_map_1.max(), thres)
assert diff_label_map_2.max(
) < thres, "The label_map diff is %f, which is bigger than %f" % (
diff_label_map_2.max(), thres)
assert diff_score_map_2.max(
) < thres, "The score map diff is %f, which is bigger than %f" % (
diff_score_map_2.max(), thres)
print("Batch images No diff")
def test_segmentation_ppliteseg_runtime():
pp_liteseg_model_url = "https://bj.bcebos.com/fastdeploy/tests/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_test.tgz"
fd.download_and_decompress(pp_liteseg_model_url, "resources")
model_path = "./resources/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_test"
# 配置runtime加载模型
runtime_option = fd.RuntimeOption()
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "deploy.yaml")
image_file_1 = os.path.join(model_path, "cityscapes_demo_1.png")
result_file_1 = os.path.join(model_path, "ppliteseg_result1.pkl")
preprocessor = fd.vision.segmentation.PaddleSegPreprocessor(config_file)
postprocessor = fd.vision.segmentation.PaddleSegPostprocessor(config_file)
postprocessor.store_score_map = True
rc.test_option.set_model_path(
model_file, params_file, model_format=ModelFormat.PADDLE)
rc.test_option.use_paddle_backend()
runtime = fd.Runtime(rc.test_option)
with open(result_file_1, "rb") as f:
expect1 = pickle.load(f)
im1 = cv2.imread(image_file_1)
print(image_file_1)
for i in range(3):
# test runtime
input_tensors, ims_info = preprocessor.run([im1])
output_tensors = runtime.infer({"x": input_tensors[0]})
results = postprocessor.run(output_tensors, ims_info)
result1 = results[0]
diff_label_map_1 = np.fabs(
np.array(result1.label_map) - np.array(expect1["label_map"]))
diff_score_map_1 = np.fabs(
np.array(result1.score_map) - np.array(expect1["score_map"]))
thres = 1e-05
assert diff_label_map_1.max(
) < thres, "The label_map diff is %f, which is bigger than %f" % (
diff_label_map_1.max(), thres)
assert diff_score_map_1.max(
) < thres, "The score map diff is %f, which is bigger than %f" % (
diff_score_map_1.max(), thres)
print("Runtime images No diff")
if __name__ == "__main__":
test_segmentation_ppliteseg()
test_segmentation_ppliteseg_runtime()