diff --git a/docs/api/vision_results/segmentation_result.md b/docs/api/vision_results/segmentation_result.md
index d684d1a23..81a48c323 100644
--- a/docs/api/vision_results/segmentation_result.md
+++ b/docs/api/vision_results/segmentation_result.md
@@ -18,7 +18,7 @@ struct DetectionResult {
```
- **label_map**: Member variable; the segmentation class of each pixel of a single image, where `label_map.size()` equals the number of pixels in the image
-- **score_map**: Member variable; the predicted class probability corresponding one-to-one with label_map (when `without_argmax` was specified at model export), or the softmax-normalized probability (when both `without_argmax` and `with_softmax` were specified at export, or when `without_argmax` was specified at export and the model [class member attribute](../../../examples/vision/segmentation/paddleseg/cpp/) `with_softmax=True` is set at model initialization)
+- **score_map**: Member variable; the predicted class probability corresponding one-to-one with label_map (when `--output_op none` was specified at model export), or the softmax-normalized probability (when `--output_op softmax` was specified at export, or when `--output_op none` was specified at export and the model [class member attribute](../../../examples/vision/segmentation/paddleseg/cpp/) `apply_softmax=True` is set at model initialization)
- **shape**: Member variable; the shape of the output image, i.e. H\*W
- **Clear()**: Member function that clears the results stored in the struct
- **Str()**: Member function that renders the information in the struct as a string (for debugging)
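+
+A minimal C++ sketch of reading these fields (assuming a `PaddleSegModel` named `model`, created as in the deployment examples):
+```c++
+fastdeploy::vision::SegmentationResult res;
+cv::Mat im = cv::imread("cityscapes_demo.png");
+model.Predict(&im, &res);
+int64_t h = res.shape[0], w = res.shape[1];  // shape is {H, W}
+// label_map stores one class id per pixel, row-major
+int64_t x = w / 2, y = h / 2;
+uint8_t label = res.label_map[y * w + x];
+if (res.contain_score_map) {
+  // score_map is only filled when the model was exported without argmax
+  float score = res.score_map[y * w + x];
+}
+```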
@@ -28,5 +28,5 @@ struct DetectionResult {
`fastdeploy.vision.SegmentationResult`
- **label_map**(list of int): Member variable; the segmentation class of each pixel of a single image
-- **score_map**(list of float): Member variable; the predicted class probability corresponding one-to-one with label_map (when `without_argmax` was specified at model export), or the softmax-normalized probability (when both `without_argmax` and `with_softmax` were specified at export, or when `without_argmax` was specified at export and the model [class member attribute](../../../examples/vision/segmentation/paddleseg/python/) `with_softmax=true` is set at model initialization)
+- **score_map**(list of float): Member variable; the predicted class probability corresponding one-to-one with label_map (when `--output_op none` was specified at model export), or the softmax-normalized probability (when `--output_op softmax` was specified at export, or when `--output_op none` was specified at export and the model [class member attribute](../../../examples/vision/segmentation/paddleseg/python/) `apply_softmax=true` is set at model initialization)
- **shape**(list of int): Member variable; the shape of the output image, i.e. H\*W
diff --git a/examples/vision/segmentation/paddleseg/README.md b/examples/vision/segmentation/paddleseg/README.md
index ee60f1e1f..758a32ef1 100644
--- a/examples/vision/segmentation/paddleseg/README.md
+++ b/examples/vision/segmentation/paddleseg/README.md
@@ -21,16 +21,19 @@ For PaddleSeg model export, refer to its documentation [Model Export](https://github.co
## Download Pretrained Models
-For developers' convenience in testing, some models exported from PaddleSeg are provided below (export settings: `input_shape` and `with_softmax` **not specified**, `without_argmax` **specified**); they can be downloaded and used directly.
+For developers' convenience in testing, some models exported from PaddleSeg are provided below (export settings: `--input_shape` **not specified**, `--output_op none` **specified**); they can be downloaded and used directly.
| Model | Params File Size | Input Shape | mIoU | mIoU (flip) | mIoU (ms+flip) |
|:---------------------------------------------------------------- |:----- |:----- | :----- | :----- | :----- |
| [Unet-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Unet_cityscapes_without_argmax_infer.tgz) | 52MB | 1024x512 | 65.00% | 66.02% | 66.89% |
| [PP-LiteSeg-T(STDC1)-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_infer.tgz) | 31MB | 1024x512 |73.10% | 73.89% | - |
-| [PP-HumanSegV1-Lite](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Lite_infer.tgz) | 543KB | 192x192 | 86.2% | - | - |
-| [PP-HumanSegV1-Server](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Server_infer.tgz) | 103MB | 512x512 | 96.47% | - | - |
+| [PP-HumanSegV1-Lite (general human segmentation)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Lite_infer.tgz) | 543KB | 192x192 | 86.2% | - | - |
+| [PP-HumanSegV2-Lite (general human segmentation)](https://bj.bcebos.com/paddle2onnx/libs/PP_HumanSegV2_Lite_192x192_infer.tgz) | 12MB | 192x192 | 92.52% | - | - |
+| [PP-HumanSegV2-Mobile (general human segmentation)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV2_Mobile_192x192_infer.tgz) | 29MB | 192x192 | 93.13% | - | - |
+| [PP-HumanSegV1-Server (general human segmentation)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Server_infer.tgz) | 103MB | 512x512 | 96.47% | - | - |
+| [Portrait-PP-HumanSegV2-Lite (portrait segmentation)](https://bj.bcebos.com/paddlehub/fastdeploy/Portrait_PP_HumanSegV2_Lite_256x144_infer.tgz) | 3.6MB | 256x144 | 96.63% | - | - |
| [FCN-HRNet-W18-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/FCN_HRNet_W18_cityscapes_without_argmax_infer.tgz) | 37MB | 1024x512 | 78.97% | 79.49% | 79.74% |
-| [Deeplabv3-ResNet50-OS8-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Deeplabv3_ResNet50_OS8_cityscapes_without_argmax_infer.tgz) | 150MB | 1024x512 | 79.90% | 80.22% | 80.47% |
+| [Deeplabv3-ResNet101-OS8-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Deeplabv3_ResNet101_OS8_cityscapes_without_argmax_infer.tgz) | 150MB | 1024x512 | 79.90% | 80.22% | 80.47% |
## Detailed Deployment Documentation
diff --git a/examples/vision/segmentation/paddleseg/cpp/README.md b/examples/vision/segmentation/paddleseg/cpp/README.md
index 16f267a28..5ff71e441 100644
--- a/examples/vision/segmentation/paddleseg/cpp/README.md
+++ b/examples/vision/segmentation/paddleseg/cpp/README.md
@@ -7,7 +7,7 @@
- 1. The software and hardware environment meets the requirements; refer to [FastDeploy Environment Requirements](../../../../../docs/environment.md)
- 2. Download the prebuilt deployment library and samples code for your development environment; refer to [FastDeploy Prebuilt Libraries](../../../../../docs/quick_start)
-Taking CPU inference on Linux as an example, run the following commands in this directory to build and test
+Taking inference on Linux as an example, run the following commands in this directory to build and test
```bash
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.2.1.tgz
@@ -25,16 +25,16 @@ wget https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png
# CPU inference
-./infer_demo Unet_cityscapes_without_argmax_infer Unet_cityscapes_without_argmax_infer cityscapes_demo.png 0
+./infer_demo Unet_cityscapes_without_argmax_infer cityscapes_demo.png 0
# GPU inference
-./infer_demo Unet_cityscapes_without_argmax_infer Unet_cityscapes_without_argmax_infer cityscapes_demo.png 1
+./infer_demo Unet_cityscapes_without_argmax_infer cityscapes_demo.png 1
# TensorRT inference on GPU
-./infer_demo Unet_cityscapes_without_argmax_infer Unet_cityscapes_without_argmax_infer cityscapes_demo.png 2
+./infer_demo Unet_cityscapes_without_argmax_infer cityscapes_demo.png 2
```
After running, the visualized result is shown in the image below
-<img: previous visualization result>
<img: visualization result>
+<img: updated visualization result>
The above commands work only on Linux or MacOS; for how to use the SDK on Windows, refer to:
@@ -80,10 +80,10 @@ PaddleSegModel loading and initialization, where model_file is the exported Paddle model
#### Preprocessing Parameters
Users can modify the following preprocessing parameters according to their actual needs, which will affect the final inference and deployment results
-> > * **is_vertical_screen**(bool): For PP-HumanSeg series models, setting this parameter to `True` indicates that the input image is in portrait orientation, i.e. its height is greater than its width
+> > * **is_vertical_screen**(bool): For PP-HumanSeg series models, setting this parameter to `true` indicates that the input image is in portrait orientation, i.e. its height is greater than its width
#### Postprocessing Parameters
-> > * **with_softmax**(bool): When `with_softmax` was not specified at model export, this parameter can be set to `True` to apply softmax normalization to the probability result (score_map) corresponding to the predicted segmentation labels (label_map)
+> > * **apply_softmax**(bool): When `--output_op softmax` was not specified at model export, this parameter can be set to `true` to apply softmax normalization to the probability result (score_map) corresponding to the predicted segmentation labels (label_map); see the sketch below
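+
+For example, both members can be set directly on the model object after construction (a sketch; `model_file`, `params_file`, and `config_file` as in the initialization section above):
+```c++
+auto model = fastdeploy::vision::segmentation::PaddleSegModel(
+    model_file, params_file, config_file);
+model.is_vertical_screen = true;  // PP-HumanSeg series: portrait input (height > width)
+model.apply_softmax = true;       // normalize score_map when the model was exported with --output_op none
+```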
- [Model Introduction](../../)
- [Python Deployment](../python)
diff --git a/examples/vision/segmentation/paddleseg/cpp/infer.cc b/examples/vision/segmentation/paddleseg/cpp/infer.cc
index 1f20ea9c8..41a90201a 100644
--- a/examples/vision/segmentation/paddleseg/cpp/infer.cc
+++ b/examples/vision/segmentation/paddleseg/cpp/infer.cc
@@ -26,6 +26,7 @@ void CpuInfer(const std::string& model_dir, const std::string& image_file) {
auto config_file = model_dir + sep + "deploy.yaml";
auto model = fastdeploy::vision::segmentation::PaddleSegModel(
model_file, params_file, config_file);
+
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
@@ -40,6 +41,7 @@ void CpuInfer(const std::string& model_dir, const std::string& image_file) {
return;
}
+ std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
@@ -54,6 +56,7 @@ void GpuInfer(const std::string& model_dir, const std::string& image_file) {
option.UseGpu();
auto model = fastdeploy::vision::segmentation::PaddleSegModel(
model_file, params_file, config_file, option);
+
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
@@ -68,6 +71,7 @@ void GpuInfer(const std::string& model_dir, const std::string& image_file) {
return;
}
+ std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
@@ -85,6 +89,7 @@ void TrtInfer(const std::string& model_dir, const std::string& image_file) {
{1, 3, 2048, 2048});
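+  // {1, 3, 2048, 2048} above is the largest shape of the dynamic-shape range
+  // handed to TensorRT; the min/opt shapes are the preceding arguments of
+  // SetTrtInputShape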
auto model = fastdeploy::vision::segmentation::PaddleSegModel(
model_file, params_file, config_file, option);
+
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
@@ -99,6 +104,7 @@ void TrtInfer(const std::string& model_dir, const std::string& image_file) {
return;
}
+ std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
diff --git a/examples/vision/segmentation/paddleseg/python/README.md b/examples/vision/segmentation/paddleseg/python/README.md
index 46fce690a..135302914 100644
--- a/examples/vision/segmentation/paddleseg/python/README.md
+++ b/examples/vision/segmentation/paddleseg/python/README.md
@@ -27,7 +27,7 @@ python infer.py --model Unet_cityscapes_without_argmax_infer --image cityscapes_
After running, the visualized result is shown in the image below
-<img: previous visualization result>
<img: visualization result>
+<img: updated visualization result>
## PaddleSegModel Python Interface
@@ -69,7 +69,7 @@ PaddleSeg model loading and initialization, where model_file, params_file, and config_fi
> > * **is_vertical_screen**(bool): For PP-HumanSeg series models, setting this parameter to `true` indicates that the input image is in portrait orientation, i.e. its height is greater than its width
#### Postprocessing Parameters
-> > * **with_softmax**(bool): When `with_softmax` was not specified at model export, this parameter can be set to `true` to apply softmax normalization to the probability result (score_map) corresponding to the predicted segmentation labels (label_map)
+> > * **apply_softmax**(bool): When `--output_op softmax` was not specified at model export, this parameter can be set to `true` to apply softmax normalization to the probability result (score_map) corresponding to the predicted segmentation labels (label_map)
## Other Documents
diff --git a/fastdeploy/vision/common/processors/mat.cc b/fastdeploy/vision/common/processors/mat.cc
index 580066229..6e5917a59 100644
--- a/fastdeploy/vision/common/processors/mat.cc
+++ b/fastdeploy/vision/common/processors/mat.cc
@@ -116,5 +116,56 @@ FDDataType Mat::Type() {
}
}
+Mat CreateFromTensor(const FDTensor& tensor) {
+ int type = tensor.dtype;
+ cv::Mat temp_mat;
+ FDASSERT(tensor.shape.size() == 3,
+ "When create FD Mat from tensor, tensor shape should be 3-Dim, HWC "
+ "layout");
+ int64_t height = tensor.shape[0];
+ int64_t width = tensor.shape[1];
+ int64_t channel = tensor.shape[2];
+ switch (type) {
+ case FDDataType::UINT8:
+ temp_mat = cv::Mat(height, width, CV_8UC(channel),
+ const_cast<void*>(tensor.Data()));
+ break;
+
+ case FDDataType::INT8:
+ temp_mat = cv::Mat(height, width, CV_8SC(channel),
+ const_cast<void*>(tensor.Data()));
+ break;
+
+ case FDDataType::INT16:
+ temp_mat = cv::Mat(height, width, CV_16SC(channel),
+ const_cast<void*>(tensor.Data()));
+ break;
+
+ case FDDataType::INT32:
+ temp_mat = cv::Mat(height, width, CV_32SC(channel),
+ const_cast<void*>(tensor.Data()));
+ break;
+
+ case FDDataType::FP32:
+ temp_mat = cv::Mat(height, width, CV_32FC(channel),
+ const_cast<void*>(tensor.Data()));
+ break;
+
+ case FDDataType::FP64:
+ temp_mat = cv::Mat(height, width, CV_64FC(channel),
+ const_cast<void*>(tensor.Data()));
+ break;
+
+ default:
+ FDASSERT(
+ false,
+ "Tensor type %d is not supported While calling CreateFromTensor.",
+ type);
+ break;
+ }
+ Mat mat = Mat(temp_mat);
+ return mat;
+}
+
} // namespace vision
} // namespace fastdeploy
diff --git a/fastdeploy/vision/common/processors/mat.h b/fastdeploy/vision/common/processors/mat.h
index cf4736238..14acfd3df 100644
--- a/fastdeploy/vision/common/processors/mat.h
+++ b/fastdeploy/vision/common/processors/mat.h
@@ -76,5 +76,7 @@ struct FASTDEPLOY_DECL Mat {
Device device = Device::CPU;
};
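+// Create a Mat from a 3-D HWC-layout FDTensor; the Mat wraps (shares) the
+// tensor's buffer rather than copying it.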
+Mat CreateFromTensor(const FDTensor& tensor);
+
} // namespace vision
} // namespace fastdeploy
diff --git a/fastdeploy/vision/segmentation/ppseg/FDTensor2CVMat.cc b/fastdeploy/vision/segmentation/ppseg/FDTensor2CVMat.cc
deleted file mode 100644
index 3e7a58366..000000000
--- a/fastdeploy/vision/segmentation/ppseg/FDTensor2CVMat.cc
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "fastdeploy/vision/segmentation/ppseg/model.h"
-
-namespace fastdeploy {
-namespace vision {
-namespace segmentation {
-
-void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result,
- bool contain_score_map) {
- // output with argmax channel is 1
- int channel = 1;
- int height = infer_result.shape[1];
- int width = infer_result.shape[2];
-
- if (contain_score_map) {
- // output without argmax and convent to NHWC
- channel = infer_result.shape[3];
- }
- // create FP32 cvmat
- if (infer_result.dtype == FDDataType::INT64) {
- FDWARNING << "The PaddleSeg model is exported with argmax. Inference "
- "result type is " +
- Str(infer_result.dtype) +
- ". If you want the edge of segmentation image more "
- "smoother. Please export model with --without_argmax "
- "--with_softmax."
- << std::endl;
- int64_t chw = channel * height * width;
- int64_t* infer_result_buffer = static_cast<int64_t*>(infer_result.Data());
- std::vector<float_t> float_result_buffer(chw);
- mat = cv::Mat(height, width, CV_32FC(channel));
- int index = 0;
- for (int i = 0; i < height; i++) {
- for (int j = 0; j < width; j++) {
- mat.at<float_t>(i, j) =
- static_cast(infer_result_buffer[index++]);
- }
- }
- } else if (infer_result.dtype == FDDataType::FP32) {
- mat = cv::Mat(height, width, CV_32FC(channel), infer_result.Data());
- }
-}
-
-} // namespace segmentation
-} // namespace vision
-} // namespace fastdeploy
diff --git a/fastdeploy/vision/segmentation/ppseg/model.cc b/fastdeploy/vision/segmentation/ppseg/model.cc
index 918bd75d2..b413351e8 100644
--- a/fastdeploy/vision/segmentation/ppseg/model.cc
+++ b/fastdeploy/vision/segmentation/ppseg/model.cc
@@ -14,7 +14,7 @@ PaddleSegModel::PaddleSegModel(const std::string& model_file,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER};
- valid_gpu_backends = {Backend::PDINFER, Backend::TRT};
+ valid_gpu_backends = {Backend::PDINFER};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
@@ -79,12 +79,32 @@ bool PaddleSegModel::BuildPreprocessPipelineFromConfig() {
}
processors_.push_back(std::make_shared<HWC2CHW>());
}
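+  // deploy.yaml records which operator was appended at export time via
+  // --output_op: argmax, softmax, or none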
+ if (cfg["Deploy"]["output_op"]) {
+ std::string output_op = cfg["Deploy"]["output_op"].as<std::string>();
+ if (output_op == "softmax") {
+ is_with_softmax = true;
+ is_with_argmax = false;
+ } else if (output_op == "argmax") {
+ is_with_softmax = false;
+ is_with_argmax = true;
+ } else if (output_op == "none") {
+ is_with_softmax = false;
+ is_with_argmax = false;
+ } else {
+ FDERROR << "Unexcepted output_op operator in deploy.yml: " << output_op
+ << "." << std::endl;
+ }
+ }
+ if (is_with_argmax) {
+ FDWARNING << "The PaddleSeg model is exported with argmax."
+ << " If you want the edge of segmentation image more"
+ << " smoother. Please export model with parameters"
+ << " --output_op softmax." << std::endl;
+ }
return true;
}
-bool PaddleSegModel::Preprocess(
- Mat* mat, FDTensor* output,
- std::map>* im_info) {
+bool PaddleSegModel::Preprocess(Mat* mat, FDTensor* output) {
for (size_t i = 0; i < processors_.size(); ++i) {
if (processors_[i]->Name().compare("Resize") == 0) {
auto processor = dynamic_cast<Resize*>(processors_[i].get());
@@ -105,10 +125,6 @@ bool PaddleSegModel::Preprocess(
}
}
- // Record output shape of preprocessed image
- (*im_info)["output_shape"] = {static_cast(mat->Height()),
- static_cast(mat->Width())};
-
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
output->name = InputInfoOfRuntime(0).name;
@@ -116,13 +132,15 @@ bool PaddleSegModel::Preprocess(
}
bool PaddleSegModel::Postprocess(
- FDTensor& infer_result, SegmentationResult* result,
- std::map>* im_info) {
+ FDTensor* infer_result, SegmentationResult* result,
+ const std::map>& im_info) {
// PaddleSeg has three types of inference output:
- // 1. output with argmax and without softmax. 3-D matrix CHW, Channel
+ // 1. output with argmax and without softmax. 3-D matrix N(C)HW, Channel
// always 1, the element in matrix is classified label_id INT64 Type.
- // 2. output without argmax and without softmax. 4-D matrix NCHW, N always
- // 1, Channel is the num of classes. The element is the logits of classes
+ // 2. output without argmax and without softmax. 4-D matrix NCHW, N (batch)
+ // always 1 (only batch size 1 is supported), Channel is the num of classes.
+ // The element is the logits of classes
// FP32
// 3. output without argmax and with softmax. 4-D matrix NCHW, the result
// of 2 with softmax layer
@@ -130,59 +148,117 @@ bool PaddleSegModel::Postprocess(
// 1. label_map
// 2. score_map(optional)
// 3. shape: 2-D HW
- FDASSERT(infer_result.dtype == FDDataType::INT64 ||
- infer_result.dtype == FDDataType::FP32,
- "Require the data type of output is int64 or fp32, but now it's %s.",
- Str(infer_result.dtype).c_str());
+ FDASSERT(infer_result->dtype == FDDataType::INT64 ||
+ infer_result->dtype == FDDataType::FP32 ||
+ infer_result->dtype == FDDataType::INT32,
+ "Require the data type of output is int64, fp32 or int32, but now "
+ "it's %s.",
+ Str(infer_result->dtype).c_str());
result->Clear();
+ FDASSERT(infer_result->shape[0] == 1, "Only support batch size = 1.");
- if (infer_result.shape.size() == 4) {
- FDASSERT(infer_result.shape[0] == 1, "Only support batch size = 1.");
+ int64_t batch = infer_result->shape[0];
+ int64_t channel = 0;
+ int64_t height = 0;
+ int64_t width = 0;
+
+ if (is_with_argmax) {
+ channel = 1;
+ height = infer_result->shape[1];
+ width = infer_result->shape[2];
+ } else {
+ channel = infer_result->shape[1];
+ height = infer_result->shape[2];
+ width = infer_result->shape[3];
+ }
+ int64_t chw = channel * height * width;
+
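+  // When the exported model lacks a softmax op but the user requested
+  // normalized scores, apply softmax over the class (channel) axis here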
+ if (!is_with_softmax && apply_softmax) {
+ Softmax(*infer_result, infer_result, 1);
+ }
+
+ if (!is_with_argmax) {
// output without argmax
result->contain_score_map = true;
- utils::NCHW2NHWC<float_t>(infer_result);
+
+ std::vector<int64_t> dim{0, 2, 3, 1};
+ Transpose(*infer_result, infer_result, dim);
}
+ // batch always 1, so ignore
+ infer_result->shape = {height, width, channel};
// for resize mat below
FDTensor new_infer_result;
Mat* mat = nullptr;
+ std::vector<float_t>* fp32_result_buffer = nullptr;
if (is_resized) {
- cv::Mat temp_mat;
- FDTensor2FP32CVMat(temp_mat, infer_result, result->contain_score_map);
-
- // original image shape
- auto iter_ipt = (*im_info).find("input_shape");
- FDASSERT(iter_ipt != im_info->end(),
+ if (infer_result->dtype == FDDataType::INT64 ||
+ infer_result->dtype == FDDataType::INT32) {
+ if (infer_result->dtype == FDDataType::INT64) {
+ int64_t* infer_result_buffer =
+ static_cast<int64_t*>(infer_result->Data());
+ // cv::resize doesn't support `CV_8S` or `CV_32S`
+ // refer to https://github.com/opencv/opencv/issues/20991
+ // https://github.com/opencv/opencv/issues/7862
+ fp32_result_buffer = new std::vector<float_t>(
+ infer_result_buffer, infer_result_buffer + chw);
+ }
+ if (infer_result->dtype == FDDataType::INT32) {
+ int32_t* infer_result_buffer =
+ static_cast<int32_t*>(infer_result->Data());
+ // cv::resize doesn't support `CV_8S` or `CV_32S`
+ // refer to https://github.com/opencv/opencv/issues/20991
+ // https://github.com/opencv/opencv/issues/7862
+ fp32_result_buffer = new std::vector<float_t>(
+ infer_result_buffer, infer_result_buffer + chw);
+ }
+ infer_result->Resize(infer_result->shape, FDDataType::FP32);
+ infer_result->SetExternalData(
+ infer_result->shape, FDDataType::FP32,
+ static_cast<void*>(fp32_result_buffer->data()));
+ }
+ auto iter_ipt = im_info.find("input_shape");
+ FDASSERT(iter_ipt != im_info.end(),
"Cannot find input_shape from im_info.");
int ipt_h = iter_ipt->second[0];
int ipt_w = iter_ipt->second[1];
-
- mat = new Mat(temp_mat);
-
- Resize::Run(mat, ipt_w, ipt_h, -1, -1, 1);
+ mat = new Mat(CreateFromTensor(*infer_result));
+ Resize::Run(mat, ipt_w, ipt_h, -1.0f, -1.0f, 1);
mat->ShareWithTensor(&new_infer_result);
- new_infer_result.shape.insert(new_infer_result.shape.begin(), 1);
result->shape = new_infer_result.shape;
} else {
- result->shape = infer_result.shape;
+ result->shape = infer_result->shape;
}
+ // output shape is 2-D HW layout, so out_num = H * W
int out_num =
- std::accumulate(result->shape.begin(), result->shape.begin() + 3, 1,
+ std::accumulate(result->shape.begin(), result->shape.begin() + 2, 1,
std::multiplies<int>());
- // NCHW remove N or CHW remove C
- result->shape.erase(result->shape.begin());
result->Resize(out_num);
if (result->contain_score_map) {
// output with label_map and score_map
- float_t* infer_result_buffer = nullptr;
- if (is_resized) {
- infer_result_buffer = static_cast<float_t*>(new_infer_result.Data());
- } else {
- infer_result_buffer = static_cast<float_t*>(infer_result.Data());
- }
+ int32_t* argmax_infer_result_buffer = nullptr;
+ float_t* score_infer_result_buffer = nullptr;
+ FDTensor argmax_infer_result;
+ FDTensor max_score_result;
+ std::vector<int64_t> reduce_dim{-1};
// argmax
- utils::ArgmaxScoreMap(infer_result_buffer, result, with_softmax);
- result->shape.erase(result->shape.begin() + 2);
+ if (is_resized) {
+ ArgMax(new_infer_result, &argmax_infer_result, -1, FDDataType::INT32);
+ Max(new_infer_result, &max_score_result, reduce_dim);
+ } else {
+ ArgMax(*infer_result, &argmax_infer_result, -1, FDDataType::INT32);
+ Max(*infer_result, &max_score_result, reduce_dim);
+ }
+ argmax_infer_result_buffer =
+ static_cast<int32_t*>(argmax_infer_result.Data());
+ score_infer_result_buffer = static_cast<float_t*>(max_score_result.Data());
+ for (int i = 0; i < out_num; i++) {
+ result->label_map[i] =
+ static_cast<uint8_t>(*(argmax_infer_result_buffer + i));
+ }
+ std::memcpy(result->score_map.data(), score_infer_result_buffer,
+ out_num * sizeof(float_t));
+
} else {
// output only with label_map
if (is_resized) {
@@ -192,13 +268,27 @@ bool PaddleSegModel::Postprocess(
result->label_map[i] = static_cast<uint8_t>(*(infer_result_buffer + i));
}
} else {
- const int64_t* infer_result_buffer =
- reinterpret_cast<const int64_t*>(infer_result.Data());
- for (int i = 0; i < out_num; i++) {
- result->label_map[i] = static_cast(*(infer_result_buffer + i));
+ if (infer_result->dtype == FDDataType::INT64) {
+ const int64_t* infer_result_buffer =
+ static_cast<const int64_t*>(infer_result->Data());
+ for (int i = 0; i < out_num; i++) {
+ result->label_map[i] =
+ static_cast<uint8_t>(*(infer_result_buffer + i));
+ }
+ }
+ if (infer_result->dtype == FDDataType::INT32) {
+ const int32_t* infer_result_buffer =
+ static_cast<const int32_t*>(infer_result->Data());
+ for (int i = 0; i < out_num; i++) {
+ result->label_map[i] =
+ static_cast<uint8_t>(*(infer_result_buffer + i));
+ }
}
}
}
+ // HWC remove C
+ result->shape.erase(result->shape.begin() + 2);
+ delete fp32_result_buffer;
delete mat;
mat = nullptr;
return true;
@@ -213,10 +303,8 @@ bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast(mat.Height()),
static_cast(mat.Width())};
- im_info["output_shape"] = {static_cast(mat.Height()),
- static_cast(mat.Width())};
- if (!Preprocess(&mat, &(processed_data[0]), &im_info)) {
+ if (!Preprocess(&mat, &(processed_data[0]))) {
FDERROR << "Failed to preprocess input data while using model:"
<< ModelName() << "." << std::endl;
return false;
@@ -227,7 +315,7 @@ bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
<< std::endl;
return false;
}
- if (!Postprocess(infer_result[0], result, &im_info)) {
+ if (!Postprocess(&infer_result[0], result, im_info)) {
FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
<< std::endl;
return false;
diff --git a/fastdeploy/vision/segmentation/ppseg/model.h b/fastdeploy/vision/segmentation/ppseg/model.h
index 06704d81a..0f649566e 100644
--- a/fastdeploy/vision/segmentation/ppseg/model.h
+++ b/fastdeploy/vision/segmentation/ppseg/model.h
@@ -18,7 +18,7 @@ class FASTDEPLOY_DECL PaddleSegModel : public FastDeployModel {
virtual bool Predict(cv::Mat* im, SegmentationResult* result);
- bool with_softmax = false;
+ bool apply_softmax = false;
bool is_vertical_screen = false;
@@ -27,20 +27,21 @@ class FASTDEPLOY_DECL PaddleSegModel : public FastDeployModel {
bool BuildPreprocessPipelineFromConfig();
- bool Preprocess(Mat* mat, FDTensor* outputs,
- std::map>* im_info);
+ bool Preprocess(Mat* mat, FDTensor* outputs);
- bool Postprocess(FDTensor& infer_result, SegmentationResult* result,
- std::map>* im_info);
+ bool Postprocess(FDTensor* infer_result, SegmentationResult* result,
+ const std::map>& im_info);
bool is_resized = false;
+ bool is_with_softmax = false;
+
+ bool is_with_argmax = true;
+
std::vector<std::shared_ptr<Processor>> processors_;
std::string config_file_;
};
-void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result,
- bool contain_score_map);
} // namespace segmentation
} // namespace vision
} // namespace fastdeploy
diff --git a/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc b/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc
index c94b7fd19..51bec778f 100644
--- a/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc
+++ b/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc
@@ -27,8 +27,8 @@ void BindPPSeg(pybind11::module& m) {
self.Predict(&mat, res);
return res;
})
- .def_readwrite("with_softmax",
- &vision::segmentation::PaddleSegModel::with_softmax)
+ .def_readwrite("apply_softmax",
+ &vision::segmentation::PaddleSegModel::apply_softmax)
.def_readwrite("is_vertical_screen",
&vision::segmentation::PaddleSegModel::is_vertical_screen);
}
diff --git a/fastdeploy/vision/utils/utils.h b/fastdeploy/vision/utils/utils.h
index 3fa36e7f1..89fc6a17e 100644
--- a/fastdeploy/vision/utils/utils.h
+++ b/fastdeploy/vision/utils/utils.h
@@ -20,6 +20,11 @@
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/result.h"
+#include "fastdeploy/function/reduce.h"
+#include "fastdeploy/function/softmax.h"
+#include "fastdeploy/function/transpose.h"
+
namespace fastdeploy {
namespace vision {
namespace utils {
@@ -51,70 +56,6 @@ std::vector<int32_t> TopKIndices(const T* array, int array_size, int topk) {
return res;
}
-template <typename T>
-void ArgmaxScoreMap(T infer_result_buffer, SegmentationResult* result,
- bool with_softmax) {
- int64_t height = result->shape[0];
- int64_t width = result->shape[1];
- int64_t num_classes = result->shape[2];
- int index = 0;
- for (size_t i = 0; i < height; ++i) {
- for (size_t j = 0; j < width; ++j) {
- int64_t s = (i * width + j) * num_classes;
- T max_class_score = std::max_element(
- infer_result_buffer + s, infer_result_buffer + s + num_classes);
- int label_id = std::distance(infer_result_buffer + s, max_class_score);
- if (label_id >= 255) {
- FDWARNING << "label_id is stored by uint8_t, now the value is bigger "
- "than 255, it's "
- << static_cast(label_id) << "." << std::endl;
- }
- result->label_map[index] = static_cast(label_id);
-
- if (with_softmax) {
- double_t total = 0;
- for (int k = 0; k < num_classes; k++) {
- total += exp(*(infer_result_buffer + s + k) - *max_class_score);
- }
- double_t softmax_class_score = 1 / total;
- result->score_map[index] = static_cast(softmax_class_score);
-
- } else {
- result->score_map[index] = static_cast(*max_class_score);
- }
- index++;
- }
- }
-}
-
-template <typename T>
-void NCHW2NHWC(FDTensor& infer_result) {
- T* infer_result_buffer = reinterpret_cast<T*>(infer_result.MutableData());
- int num = infer_result.shape[0];
- int channel = infer_result.shape[1];
- int height = infer_result.shape[2];
- int width = infer_result.shape[3];
- int chw = channel * height * width;
- int wc = width * channel;
- int wh = width * height;
- std::vector<T> hwc_data(chw);
- int index = 0;
- for (int n = 0; n < num; n++) {
- for (int c = 0; c < channel; c++) {
- for (int h = 0; h < height; h++) {
- for (int w = 0; w < width; w++) {
- hwc_data[n * chw + h * wc + w * channel + c] =
- *(infer_result_buffer + index);
- index++;
- }
- }
- }
- }
- std::memcpy(infer_result.MutableData(), hwc_data.data(),
- num * chw * sizeof(T));
- infer_result.shape = {num, height, width, channel};
-}
-
void NMS(DetectionResult* output, float iou_threshold = 0.5);
void NMS(FaceDetectionResult* result, float iou_threshold = 0.5);
diff --git a/python/fastdeploy/vision/segmentation/ppseg/__init__.py b/python/fastdeploy/vision/segmentation/ppseg/__init__.py
index a7c15f50b..9a3a5b577 100644
--- a/python/fastdeploy/vision/segmentation/ppseg/__init__.py
+++ b/python/fastdeploy/vision/segmentation/ppseg/__init__.py
@@ -37,15 +37,15 @@ class PaddleSegModel(FastDeployModel):
return self._model.predict(input_image)
@property
- def with_softmax(self):
- return self._model.with_softmax
+ def apply_softmax(self):
+ return self._model.apply_softmax
- @with_softmax.setter
- def with_softmax(self, value):
+ @apply_softmax.setter
+ def apply_softmax(self, value):
assert isinstance(
value,
- bool), "The value to set `with_softmax` must be type of bool."
- self._model.with_softmax = value
+ bool), "The value to set `apply_softmax` must be type of bool."
+ self._model.apply_softmax = value
@property
def is_vertical_screen(self):