diff --git a/docs/api/vision_results/segmentation_result.md b/docs/api/vision_results/segmentation_result.md index d684d1a23..81a48c323 100644 --- a/docs/api/vision_results/segmentation_result.md +++ b/docs/api/vision_results/segmentation_result.md @@ -18,7 +18,7 @@ struct DetectionResult { ``` - **label_map**: 成员变量,表示单张图片每个像素点的分割类别,`label_map.size()`表示图片像素点的个数 -- **score_map**: 成员变量,与label_map一一对应的所预测的分割类别概率值(当导出模型时指定`without_argmax`)或者经过softmax归一化化后的概率值(当导出模型时指定`without_argmax`以及`with_softmax`或者导出模型时指定`without_argmax`同时模型初始化的时候设置模型[类成员属性](../../../examples/vision/segmentation/paddleseg/cpp/)`with_softmax=True`) +- **score_map**: 成员变量,与label_map一一对应的所预测的分割类别概率值(当导出模型时指定`--output_op argmax`)或者经过softmax归一化化后的概率值(当导出模型时指定`--output_op softmax`或者导出模型时指定`--output_op none`同时模型初始化的时候设置模型[类成员属性](../../../examples/vision/segmentation/paddleseg/cpp/)`apply_softmax=True`) - **shape**: 成员变量,表示输出图片的shape,为H\*W - **Clear()**: 成员函数,用于清除结构体中存储的结果 - **Str()**: 成员函数,将结构体中的信息以字符串形式输出(用于Debug) @@ -28,5 +28,5 @@ struct DetectionResult { `fastdeploy.vision.SegmentationResult` - **label_map**(list of int): 成员变量,表示单张图片每个像素点的分割类别 -- **score_map**(list of float): 成员变量,与label_map一一对应的所预测的分割类别概率值(当导出模型时指定`without_argmax`)或者经过softmax归一化化后的概率值(当导出模型时指定`without_argmax`以及`with_softmax`或者导出模型时指定`without_argmax`同时模型初始化的时候设置模型[类成员属性](../../../examples/vision/segmentation/paddleseg/python/)`with_softmax=true`) +- **score_map**(list of float): 成员变量,与label_map一一对应的所预测的分割类别概率值(当导出模型时指定`--output_op argmax`)或者经过softmax归一化化后的概率值(当导出模型时指定`--output_op softmax`或者导出模型时指定`--output_op none`同时模型初始化的时候设置模型[类成员属性](../../../examples/vision/segmentation/paddleseg/python/)`apply_softmax=true`) - **shape**(list of int): 成员变量,表示输出图片的shape,为H\*W diff --git a/examples/vision/segmentation/paddleseg/README.md b/examples/vision/segmentation/paddleseg/README.md index ee60f1e1f..758a32ef1 100644 --- a/examples/vision/segmentation/paddleseg/README.md +++ b/examples/vision/segmentation/paddleseg/README.md @@ -21,16 +21,19 @@ PaddleSeg模型导出,请参考其文档说明[模型导出](https://github.co ## 下载预训练模型 -为了方便开发者的测试,下面提供了PaddleSeg导出的部分模型(导出方式为:**不指定**`input_shape`和`with_softmax`,**指定**`without_argmax`),开发者可直接下载使用。 +为了方便开发者的测试,下面提供了PaddleSeg导出的部分模型(导出方式为:**不指定**`--input_shape`,**指定**`--output_op none`),开发者可直接下载使用。 | 模型 | 参数文件大小 |输入Shape | mIoU | mIoU (flip) | mIoU (ms+flip) | |:---------------------------------------------------------------- |:----- |:----- | :----- | :----- | :----- | | [Unet-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Unet_cityscapes_without_argmax_infer.tgz) | 52MB | 1024x512 | 65.00% | 66.02% | 66.89% | | [PP-LiteSeg-T(STDC1)-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_infer.tgz) | 31MB | 1024x512 |73.10% | 73.89% | - | -| [PP-HumanSegV1-Lite](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Lite_infer.tgz) | 543KB | 192x192 | 86.2% | - | - | -| [PP-HumanSegV1-Server](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Server_infer.tgz) | 103MB | 512x512 | 96.47% | - | - | +| [PP-HumanSegV1-Lite(通用人像分割模型)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Lite_infer.tgz) | 543KB | 192x192 | 86.2% | - | - | +| [PP-HumanSegV2-Lite(通用人像分割模型)](https://bj.bcebos.com/paddle2onnx/libs/PP_HumanSegV2_Lite_192x192_infer.tgz) | 12MB | 192x192 | 92.52% | - | - | +| [PP-HumanSegV2-Mobile(通用人像分割模型)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV2_Mobile_192x192_infer.tgz) | 29MB | 192x192 | 93.13% | - | - | +| [PP-HumanSegV1-Server(通用人像分割模型)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Server_infer.tgz) | 103MB | 512x512 | 96.47% | - | - | +| [Portait-PP-HumanSegV2_Lite(肖像分割模型)](https://bj.bcebos.com/paddlehub/fastdeploy/Portrait_PP_HumanSegV2_Lite_256x144_infer.tgz) | 3.6M | 256x144 | 96.63% | - | - | | [FCN-HRNet-W18-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/FCN_HRNet_W18_cityscapes_without_argmax_infer.tgz) | 37MB | 1024x512 | 78.97% | 79.49% | 79.74% | -| [Deeplabv3-ResNet50-OS8-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Deeplabv3_ResNet50_OS8_cityscapes_without_argmax_infer.tgz) | 150MB | 1024x512 | 79.90% | 80.22% | 80.47% | +| [Deeplabv3-ResNet101-OS8-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Deeplabv3_ResNet101_OS8_cityscapes_without_argmax_infer.tgz) | 150MB | 1024x512 | 79.90% | 80.22% | 80.47% | ## 详细部署文档 diff --git a/examples/vision/segmentation/paddleseg/cpp/README.md b/examples/vision/segmentation/paddleseg/cpp/README.md index 16f267a28..5ff71e441 100644 --- a/examples/vision/segmentation/paddleseg/cpp/README.md +++ b/examples/vision/segmentation/paddleseg/cpp/README.md @@ -7,7 +7,7 @@ - 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/environment.md) - 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/quick_start) -以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试 +以Linux上推理为例,在本目录执行如下命令即可完成编译测试 ```bash wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.2.1.tgz @@ -25,16 +25,16 @@ wget https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png # CPU推理 -./infer_demo Unet_cityscapes_without_argmax_infer Unet_cityscapes_without_argmax_infer cityscapes_demo.png 0 +./infer_demo Unet_cityscapes_without_argmax_infer cityscapes_demo.png 0 # GPU推理 -./infer_demo Unet_cityscapes_without_argmax_infer Unet_cityscapes_without_argmax_infer cityscapes_demo.png 1 +./infer_demo Unet_cityscapes_without_argmax_infer cityscapes_demo.png 1 # GPU上TensorRT推理 -./infer_demo Unet_cityscapes_without_argmax_infer Unet_cityscapes_without_argmax_infer cityscapes_demo.png 2 +./infer_demo Unet_cityscapes_without_argmax_infer cityscapes_demo.png 2 ``` 运行完成可视化结果如下图所示
- +
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: @@ -80,10 +80,10 @@ PaddleSegModel模型加载和初始化,其中model_file为导出的Paddle模 #### 预处理参数 用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果 -> > * **is_vertical_screen**(bool): PP-HumanSeg系列模型通过设置此参数为`True`表明输入图片是竖屏,即height大于width的图片 +> > * **is_vertical_screen**(bool): PP-HumanSeg系列模型通过设置此参数为`true`表明输入图片是竖屏,即height大于width的图片 #### 后处理参数 -> > * **with_softmax**(bool): 当模型导出时,并未指定`with_softmax`参数,可通过此设置此参数为`True`,将预测的输出分割标签(label_map)对应的概率结果(score_map)做softmax归一化处理 +> > * **appy_softmax**(bool): 当模型导出时,并未指定`apply_softmax`参数,可通过此设置此参数为`true`,将预测的输出分割标签(label_map)对应的概率结果(score_map)做softmax归一化处理 - [模型介绍](../../) - [Python部署](../python) diff --git a/examples/vision/segmentation/paddleseg/cpp/infer.cc b/examples/vision/segmentation/paddleseg/cpp/infer.cc index 1f20ea9c8..41a90201a 100644 --- a/examples/vision/segmentation/paddleseg/cpp/infer.cc +++ b/examples/vision/segmentation/paddleseg/cpp/infer.cc @@ -26,6 +26,7 @@ void CpuInfer(const std::string& model_dir, const std::string& image_file) { auto config_file = model_dir + sep + "deploy.yaml"; auto model = fastdeploy::vision::segmentation::PaddleSegModel( model_file, params_file, config_file); + if (!model.Initialized()) { std::cerr << "Failed to initialize." << std::endl; return; @@ -40,6 +41,7 @@ void CpuInfer(const std::string& model_dir, const std::string& image_file) { return; } + std::cout << res.Str() << std::endl; auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res); cv::imwrite("vis_result.jpg", vis_im); std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; @@ -54,6 +56,7 @@ void GpuInfer(const std::string& model_dir, const std::string& image_file) { option.UseGpu(); auto model = fastdeploy::vision::segmentation::PaddleSegModel( model_file, params_file, config_file, option); + if (!model.Initialized()) { std::cerr << "Failed to initialize." << std::endl; return; @@ -68,6 +71,7 @@ void GpuInfer(const std::string& model_dir, const std::string& image_file) { return; } + std::cout << res.Str() << std::endl; auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res); cv::imwrite("vis_result.jpg", vis_im); std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; @@ -85,6 +89,7 @@ void TrtInfer(const std::string& model_dir, const std::string& image_file) { {1, 3, 2048, 2048}); auto model = fastdeploy::vision::segmentation::PaddleSegModel( model_file, params_file, config_file, option); + if (!model.Initialized()) { std::cerr << "Failed to initialize." << std::endl; return; @@ -99,6 +104,7 @@ void TrtInfer(const std::string& model_dir, const std::string& image_file) { return; } + std::cout << res.Str() << std::endl; auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res); cv::imwrite("vis_result.jpg", vis_im); std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; diff --git a/examples/vision/segmentation/paddleseg/python/README.md b/examples/vision/segmentation/paddleseg/python/README.md index 46fce690a..135302914 100644 --- a/examples/vision/segmentation/paddleseg/python/README.md +++ b/examples/vision/segmentation/paddleseg/python/README.md @@ -27,7 +27,7 @@ python infer.py --model Unet_cityscapes_without_argmax_infer --image cityscapes_ 运行完成可视化结果如下图所示
- +
## PaddleSegModel Python接口 @@ -69,7 +69,7 @@ PaddleSeg模型加载和初始化,其中model_file, params_file以及config_fi > > * **is_vertical_screen**(bool): PP-HumanSeg系列模型通过设置此参数为`true`表明输入图片是竖屏,即height大于width的图片 #### 后处理参数 -> > * **with_softmax**(bool): 当模型导出时,并未指定`with_softmax`参数,可通过此设置此参数为`true`,将预测的输出分割标签(label_map)对应的概率结果(score_map)做softmax归一化处理 +> > * **apply_softmax**(bool): 当模型导出时,并未指定`apply_softmax`参数,可通过此设置此参数为`true`,将预测的输出分割标签(label_map)对应的概率结果(score_map)做softmax归一化处理 ## 其它文档 diff --git a/fastdeploy/vision/common/processors/mat.cc b/fastdeploy/vision/common/processors/mat.cc index 580066229..6e5917a59 100644 --- a/fastdeploy/vision/common/processors/mat.cc +++ b/fastdeploy/vision/common/processors/mat.cc @@ -116,5 +116,56 @@ FDDataType Mat::Type() { } } +Mat CreateFromTensor(const FDTensor& tensor) { + int type = tensor.dtype; + cv::Mat temp_mat; + FDASSERT(tensor.shape.size() == 3, + "When create FD Mat from tensor, tensor shape should be 3-Dim, HWC " + "layout"); + int64_t height = tensor.shape[0]; + int64_t width = tensor.shape[1]; + int64_t channel = tensor.shape[2]; + switch (type) { + case FDDataType::UINT8: + temp_mat = cv::Mat(height, width, CV_8UC(channel), + const_cast(tensor.Data())); + break; + + case FDDataType::INT8: + temp_mat = cv::Mat(height, width, CV_8SC(channel), + const_cast(tensor.Data())); + break; + + case FDDataType::INT16: + temp_mat = cv::Mat(height, width, CV_16SC(channel), + const_cast(tensor.Data())); + break; + + case FDDataType::INT32: + temp_mat = cv::Mat(height, width, CV_32SC(channel), + const_cast(tensor.Data())); + break; + + case FDDataType::FP32: + temp_mat = cv::Mat(height, width, CV_32FC(channel), + const_cast(tensor.Data())); + break; + + case FDDataType::FP64: + temp_mat = cv::Mat(height, width, CV_64FC(channel), + const_cast(tensor.Data())); + break; + + default: + FDASSERT( + false, + "Tensor type %d is not supported While calling CreateFromTensor.", + type); + break; + } + Mat mat = Mat(temp_mat); + return mat; +} + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/mat.h b/fastdeploy/vision/common/processors/mat.h index cf4736238..14acfd3df 100644 --- a/fastdeploy/vision/common/processors/mat.h +++ b/fastdeploy/vision/common/processors/mat.h @@ -76,5 +76,7 @@ struct FASTDEPLOY_DECL Mat { Device device = Device::CPU; }; +Mat CreateFromTensor(const FDTensor& tensor); + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/segmentation/ppseg/FDTensor2CVMat.cc b/fastdeploy/vision/segmentation/ppseg/FDTensor2CVMat.cc deleted file mode 100644 index 3e7a58366..000000000 --- a/fastdeploy/vision/segmentation/ppseg/FDTensor2CVMat.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fastdeploy/vision/segmentation/ppseg/model.h" - -namespace fastdeploy { -namespace vision { -namespace segmentation { - -void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result, - bool contain_score_map) { - // output with argmax channel is 1 - int channel = 1; - int height = infer_result.shape[1]; - int width = infer_result.shape[2]; - - if (contain_score_map) { - // output without argmax and convent to NHWC - channel = infer_result.shape[3]; - } - // create FP32 cvmat - if (infer_result.dtype == FDDataType::INT64) { - FDWARNING << "The PaddleSeg model is exported with argmax. Inference " - "result type is " + - Str(infer_result.dtype) + - ". If you want the edge of segmentation image more " - "smoother. Please export model with --without_argmax " - "--with_softmax." - << std::endl; - int64_t chw = channel * height * width; - int64_t* infer_result_buffer = static_cast(infer_result.Data()); - std::vector float_result_buffer(chw); - mat = cv::Mat(height, width, CV_32FC(channel)); - int index = 0; - for (int i = 0; i < height; i++) { - for (int j = 0; j < width; j++) { - mat.at(i, j) = - static_cast(infer_result_buffer[index++]); - } - } - } else if (infer_result.dtype == FDDataType::FP32) { - mat = cv::Mat(height, width, CV_32FC(channel), infer_result.Data()); - } -} - -} // namespace segmentation -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/segmentation/ppseg/model.cc b/fastdeploy/vision/segmentation/ppseg/model.cc index 918bd75d2..b413351e8 100644 --- a/fastdeploy/vision/segmentation/ppseg/model.cc +++ b/fastdeploy/vision/segmentation/ppseg/model.cc @@ -14,7 +14,7 @@ PaddleSegModel::PaddleSegModel(const std::string& model_file, const ModelFormat& model_format) { config_file_ = config_file; valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER}; - valid_gpu_backends = {Backend::PDINFER, Backend::TRT}; + valid_gpu_backends = {Backend::PDINFER}; runtime_option = custom_option; runtime_option.model_format = model_format; runtime_option.model_file = model_file; @@ -79,12 +79,32 @@ bool PaddleSegModel::BuildPreprocessPipelineFromConfig() { } processors_.push_back(std::make_shared()); } + if (cfg["Deploy"]["output_op"]) { + std::string output_op = cfg["Deploy"]["output_op"].as(); + if (output_op == "softmax") { + is_with_softmax = true; + is_with_argmax = false; + } else if (output_op == "argmax") { + is_with_softmax = false; + is_with_argmax = true; + } else if (output_op == "none") { + is_with_softmax = false; + is_with_argmax = false; + } else { + FDERROR << "Unexcepted output_op operator in deploy.yml: " << output_op + << "." << std::endl; + } + } + if (is_with_argmax) { + FDWARNING << "The PaddleSeg model is exported with argmax." + << " If you want the edge of segmentation image more" + << " smoother. Please export model with parameters" + << " --output_op softmax." << std::endl; + } return true; } -bool PaddleSegModel::Preprocess( - Mat* mat, FDTensor* output, - std::map>* im_info) { +bool PaddleSegModel::Preprocess(Mat* mat, FDTensor* output) { for (size_t i = 0; i < processors_.size(); ++i) { if (processors_[i]->Name().compare("Resize") == 0) { auto processor = dynamic_cast(processors_[i].get()); @@ -105,10 +125,6 @@ bool PaddleSegModel::Preprocess( } } - // Record output shape of preprocessed image - (*im_info)["output_shape"] = {static_cast(mat->Height()), - static_cast(mat->Width())}; - mat->ShareWithTensor(output); output->shape.insert(output->shape.begin(), 1); output->name = InputInfoOfRuntime(0).name; @@ -116,13 +132,15 @@ bool PaddleSegModel::Preprocess( } bool PaddleSegModel::Postprocess( - FDTensor& infer_result, SegmentationResult* result, - std::map>* im_info) { + FDTensor* infer_result, SegmentationResult* result, + const std::map>& im_info) { // PaddleSeg has three types of inference output: - // 1. output with argmax and without softmax. 3-D matrix CHW, Channel + // 1. output with argmax and without softmax. 3-D matrix N(C)HW, Channel // always 1, the element in matrix is classified label_id INT64 Type. - // 2. output without argmax and without softmax. 4-D matrix NCHW, N always - // 1, Channel is the num of classes. The element is the logits of classes + // 2. output without argmax and without softmax. 4-D matrix NCHW, N(batch) + // always + // 1(only support batch size 1), Channel is the num of classes. The + // element is the logits of classes // FP32 // 3. output without argmax and with softmax. 4-D matrix NCHW, the result // of 2 with softmax layer @@ -130,59 +148,117 @@ bool PaddleSegModel::Postprocess( // 1. label_map // 2. score_map(optional) // 3. shape: 2-D HW - FDASSERT(infer_result.dtype == FDDataType::INT64 || - infer_result.dtype == FDDataType::FP32, - "Require the data type of output is int64 or fp32, but now it's %s.", - Str(infer_result.dtype).c_str()); + FDASSERT(infer_result->dtype == FDDataType::INT64 || + infer_result->dtype == FDDataType::FP32 || + infer_result->dtype == FDDataType::INT32, + "Require the data type of output is int64, fp32 or int32, but now " + "it's %s.", + Str(infer_result->dtype).c_str()); result->Clear(); + FDASSERT(infer_result->shape[0] == 1, "Only support batch size = 1."); - if (infer_result.shape.size() == 4) { - FDASSERT(infer_result.shape[0] == 1, "Only support batch size = 1."); + int64_t batch = infer_result->shape[0]; + int64_t channel = 0; + int64_t height = 0; + int64_t width = 0; + + if (is_with_argmax) { + channel = 1; + height = infer_result->shape[1]; + width = infer_result->shape[2]; + } else { + channel = infer_result->shape[1]; + height = infer_result->shape[2]; + width = infer_result->shape[3]; + } + int64_t chw = channel * height * width; + + if (!is_with_softmax && apply_softmax) { + Softmax(*infer_result, infer_result, 1); + } + + if (!is_with_argmax) { // output without argmax result->contain_score_map = true; - utils::NCHW2NHWC(infer_result); + + std::vector dim{0, 2, 3, 1}; + Transpose(*infer_result, infer_result, dim); } + // batch always 1, so ignore + infer_result->shape = {height, width, channel}; // for resize mat below FDTensor new_infer_result; Mat* mat = nullptr; + std::vector* fp32_result_buffer = nullptr; if (is_resized) { - cv::Mat temp_mat; - FDTensor2FP32CVMat(temp_mat, infer_result, result->contain_score_map); - - // original image shape - auto iter_ipt = (*im_info).find("input_shape"); - FDASSERT(iter_ipt != im_info->end(), + if (infer_result->dtype == FDDataType::INT64 || + infer_result->dtype == FDDataType::INT32) { + if (infer_result->dtype == FDDataType::INT64) { + int64_t* infer_result_buffer = + static_cast(infer_result->Data()); + // cv::resize don't support `CV_8S` or `CV_32S` + // refer to https://github.com/opencv/opencv/issues/20991 + // https://github.com/opencv/opencv/issues/7862 + fp32_result_buffer = new std::vector( + infer_result_buffer, infer_result_buffer + chw); + } + if (infer_result->dtype == FDDataType::INT32) { + int32_t* infer_result_buffer = + static_cast(infer_result->Data()); + // cv::resize don't support `CV_8S` or `CV_32S` + // refer to https://github.com/opencv/opencv/issues/20991 + // https://github.com/opencv/opencv/issues/7862 + fp32_result_buffer = new std::vector( + infer_result_buffer, infer_result_buffer + chw); + } + infer_result->Resize(infer_result->shape, FDDataType::FP32); + infer_result->SetExternalData( + infer_result->shape, FDDataType::FP32, + static_cast(fp32_result_buffer->data())); + } + auto iter_ipt = im_info.find("input_shape"); + FDASSERT(iter_ipt != im_info.end(), "Cannot find input_shape from im_info."); int ipt_h = iter_ipt->second[0]; int ipt_w = iter_ipt->second[1]; - - mat = new Mat(temp_mat); - - Resize::Run(mat, ipt_w, ipt_h, -1, -1, 1); + mat = new Mat(CreateFromTensor(*infer_result)); + Resize::Run(mat, ipt_w, ipt_h, -1.0f, -1.0f, 1); mat->ShareWithTensor(&new_infer_result); - new_infer_result.shape.insert(new_infer_result.shape.begin(), 1); result->shape = new_infer_result.shape; } else { - result->shape = infer_result.shape; + result->shape = infer_result->shape; } + // output shape is 2-D HW layout, so out_num = H * W int out_num = - std::accumulate(result->shape.begin(), result->shape.begin() + 3, 1, + std::accumulate(result->shape.begin(), result->shape.begin() + 2, 1, std::multiplies()); - // NCHW remove N or CHW remove C - result->shape.erase(result->shape.begin()); result->Resize(out_num); if (result->contain_score_map) { // output with label_map and score_map - float_t* infer_result_buffer = nullptr; - if (is_resized) { - infer_result_buffer = static_cast(new_infer_result.Data()); - } else { - infer_result_buffer = static_cast(infer_result.Data()); - } + int32_t* argmax_infer_result_buffer = nullptr; + float_t* score_infer_result_buffer = nullptr; + FDTensor argmax_infer_result; + FDTensor max_score_result; + std::vector reduce_dim{-1}; // argmax - utils::ArgmaxScoreMap(infer_result_buffer, result, with_softmax); - result->shape.erase(result->shape.begin() + 2); + if (is_resized) { + ArgMax(new_infer_result, &argmax_infer_result, -1, FDDataType::INT32); + Max(new_infer_result, &max_score_result, reduce_dim); + } else { + ArgMax(*infer_result, &argmax_infer_result, -1, FDDataType::INT32); + Max(*infer_result, &max_score_result, reduce_dim); + } + argmax_infer_result_buffer = + static_cast(argmax_infer_result.Data()); + score_infer_result_buffer = static_cast(max_score_result.Data()); + for (int i = 0; i < out_num; i++) { + result->label_map[i] = + static_cast(*(argmax_infer_result_buffer + i)); + } + std::memcpy(result->score_map.data(), score_infer_result_buffer, + out_num * sizeof(float_t)); + } else { // output only with label_map if (is_resized) { @@ -192,13 +268,27 @@ bool PaddleSegModel::Postprocess( result->label_map[i] = static_cast(*(infer_result_buffer + i)); } } else { - const int64_t* infer_result_buffer = - reinterpret_cast(infer_result.Data()); - for (int i = 0; i < out_num; i++) { - result->label_map[i] = static_cast(*(infer_result_buffer + i)); + if (infer_result->dtype == FDDataType::INT64) { + const int64_t* infer_result_buffer = + static_cast(infer_result->Data()); + for (int i = 0; i < out_num; i++) { + result->label_map[i] = + static_cast(*(infer_result_buffer + i)); + } + } + if (infer_result->dtype == FDDataType::INT32) { + const int32_t* infer_result_buffer = + static_cast(infer_result->Data()); + for (int i = 0; i < out_num; i++) { + result->label_map[i] = + static_cast(*(infer_result_buffer + i)); + } } } } + // HWC remove C + result->shape.erase(result->shape.begin() + 2); + delete fp32_result_buffer; delete mat; mat = nullptr; return true; @@ -213,10 +303,8 @@ bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) { // Record the shape of image and the shape of preprocessed image im_info["input_shape"] = {static_cast(mat.Height()), static_cast(mat.Width())}; - im_info["output_shape"] = {static_cast(mat.Height()), - static_cast(mat.Width())}; - if (!Preprocess(&mat, &(processed_data[0]), &im_info)) { + if (!Preprocess(&mat, &(processed_data[0]))) { FDERROR << "Failed to preprocess input data while using model:" << ModelName() << "." << std::endl; return false; @@ -227,7 +315,7 @@ bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) { << std::endl; return false; } - if (!Postprocess(infer_result[0], result, &im_info)) { + if (!Postprocess(&infer_result[0], result, im_info)) { FDERROR << "Failed to postprocess while using model:" << ModelName() << "." << std::endl; return false; diff --git a/fastdeploy/vision/segmentation/ppseg/model.h b/fastdeploy/vision/segmentation/ppseg/model.h index 06704d81a..0f649566e 100644 --- a/fastdeploy/vision/segmentation/ppseg/model.h +++ b/fastdeploy/vision/segmentation/ppseg/model.h @@ -18,7 +18,7 @@ class FASTDEPLOY_DECL PaddleSegModel : public FastDeployModel { virtual bool Predict(cv::Mat* im, SegmentationResult* result); - bool with_softmax = false; + bool apply_softmax = false; bool is_vertical_screen = false; @@ -27,20 +27,21 @@ class FASTDEPLOY_DECL PaddleSegModel : public FastDeployModel { bool BuildPreprocessPipelineFromConfig(); - bool Preprocess(Mat* mat, FDTensor* outputs, - std::map>* im_info); + bool Preprocess(Mat* mat, FDTensor* outputs); - bool Postprocess(FDTensor& infer_result, SegmentationResult* result, - std::map>* im_info); + bool Postprocess(FDTensor* infer_result, SegmentationResult* result, + const std::map>& im_info); bool is_resized = false; + bool is_with_softmax = false; + + bool is_with_argmax = true; + std::vector> processors_; std::string config_file_; }; -void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result, - bool contain_score_map); } // namespace segmentation } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc b/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc index c94b7fd19..51bec778f 100644 --- a/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc +++ b/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc @@ -27,8 +27,8 @@ void BindPPSeg(pybind11::module& m) { self.Predict(&mat, res); return res; }) - .def_readwrite("with_softmax", - &vision::segmentation::PaddleSegModel::with_softmax) + .def_readwrite("apply_softmax", + &vision::segmentation::PaddleSegModel::apply_softmax) .def_readwrite("is_vertical_screen", &vision::segmentation::PaddleSegModel::is_vertical_screen); } diff --git a/fastdeploy/vision/utils/utils.h b/fastdeploy/vision/utils/utils.h index 3fa36e7f1..89fc6a17e 100644 --- a/fastdeploy/vision/utils/utils.h +++ b/fastdeploy/vision/utils/utils.h @@ -20,6 +20,11 @@ #include "fastdeploy/utils/utils.h" #include "fastdeploy/vision/common/result.h" +// #include "unsupported/Eigen/CXX11/Tensor" +#include "fastdeploy/function/reduce.h" +#include "fastdeploy/function/softmax.h" +#include "fastdeploy/function/transpose.h" + namespace fastdeploy { namespace vision { namespace utils { @@ -51,70 +56,6 @@ std::vector TopKIndices(const T* array, int array_size, int topk) { return res; } -template -void ArgmaxScoreMap(T infer_result_buffer, SegmentationResult* result, - bool with_softmax) { - int64_t height = result->shape[0]; - int64_t width = result->shape[1]; - int64_t num_classes = result->shape[2]; - int index = 0; - for (size_t i = 0; i < height; ++i) { - for (size_t j = 0; j < width; ++j) { - int64_t s = (i * width + j) * num_classes; - T max_class_score = std::max_element( - infer_result_buffer + s, infer_result_buffer + s + num_classes); - int label_id = std::distance(infer_result_buffer + s, max_class_score); - if (label_id >= 255) { - FDWARNING << "label_id is stored by uint8_t, now the value is bigger " - "than 255, it's " - << static_cast(label_id) << "." << std::endl; - } - result->label_map[index] = static_cast(label_id); - - if (with_softmax) { - double_t total = 0; - for (int k = 0; k < num_classes; k++) { - total += exp(*(infer_result_buffer + s + k) - *max_class_score); - } - double_t softmax_class_score = 1 / total; - result->score_map[index] = static_cast(softmax_class_score); - - } else { - result->score_map[index] = static_cast(*max_class_score); - } - index++; - } - } -} - -template -void NCHW2NHWC(FDTensor& infer_result) { - T* infer_result_buffer = reinterpret_cast(infer_result.MutableData()); - int num = infer_result.shape[0]; - int channel = infer_result.shape[1]; - int height = infer_result.shape[2]; - int width = infer_result.shape[3]; - int chw = channel * height * width; - int wc = width * channel; - int wh = width * height; - std::vector hwc_data(chw); - int index = 0; - for (int n = 0; n < num; n++) { - for (int c = 0; c < channel; c++) { - for (int h = 0; h < height; h++) { - for (int w = 0; w < width; w++) { - hwc_data[n * chw + h * wc + w * channel + c] = - *(infer_result_buffer + index); - index++; - } - } - } - } - std::memcpy(infer_result.MutableData(), hwc_data.data(), - num * chw * sizeof(T)); - infer_result.shape = {num, height, width, channel}; -} - void NMS(DetectionResult* output, float iou_threshold = 0.5); void NMS(FaceDetectionResult* result, float iou_threshold = 0.5); diff --git a/python/fastdeploy/vision/segmentation/ppseg/__init__.py b/python/fastdeploy/vision/segmentation/ppseg/__init__.py index a7c15f50b..9a3a5b577 100644 --- a/python/fastdeploy/vision/segmentation/ppseg/__init__.py +++ b/python/fastdeploy/vision/segmentation/ppseg/__init__.py @@ -37,15 +37,15 @@ class PaddleSegModel(FastDeployModel): return self._model.predict(input_image) @property - def with_softmax(self): - return self._model.with_softmax + def apply_softmax(self): + return self._model.apply_softmax - @with_softmax.setter - def with_softmax(self, value): + @apply_softmax.setter + def apply_softmax(self, value): assert isinstance( value, - bool), "The value to set `with_softmax` must be type of bool." - self._model.with_softmax = value + bool), "The value to set `apply_softmax` must be type of bool." + self._model.apply_softmax = value @property def is_vertical_screen(self):