mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-17 14:11:14 +08:00
[vision] support PaddleDetection/MaskRCNN model (#218)
* [vision] support padddetection maskrcnn * [vision] fixed instance mask visualize func * [vision] optimize instance mask visualize func * [docs] update ppdet/maskrcnn docs * [vision] update maskrcnn implementation Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -19,3 +19,4 @@ fastdeploy/ThirdPartyNotices*
|
||||
fastdeploy/libs/third_libs
|
||||
csrc/fastdeploy/core/config.h
|
||||
csrc/fastdeploy/pybind/main.cc
|
||||
__pycache__
|
@@ -11,9 +11,7 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
|
||||
#ifdef WITH_GPU
|
||||
@@ -78,6 +76,13 @@ void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
|
||||
device = new_device;
|
||||
}
|
||||
|
||||
// Insert a new axis of extent 1 into the tensor shape at position `axis`.
//
// Only the `shape` vector is modified by this function.
// `axis` must lie in the closed range [0, shape.size()]; passing
// shape.size() appends the new axis at the end.
void FDTensor::ExpandDim(int64_t axis) {
  // Compare in the signed domain so the range check never mixes
  // int64_t with size_t.
  const int64_t ndim = static_cast<int64_t>(shape.size());
  FDASSERT(axis >= 0 && axis <= ndim,
           "The allowed 'axis' must be in range of [0, %lu]!",
           static_cast<unsigned long>(ndim));
  shape.insert(shape.begin() + axis, 1);
}
|
||||
|
||||
void FDTensor::Allocate(const std::vector<int64_t>& new_shape,
|
||||
const FDDataType& data_type,
|
||||
const std::string& tensor_name,
|
||||
|
@@ -67,6 +67,10 @@ struct FASTDEPLOY_DECL FDTensor {
|
||||
const FDDataType& data_type, void* data_buffer,
|
||||
const Device& new_device = Device::CPU);
|
||||
|
||||
// Expand the shape of a Tensor. Insert a new axis that will appear
|
||||
// at the `axis` position in the expanded Tensor shape.
|
||||
void ExpandDim(int64_t axis = 0);
|
||||
|
||||
// Initialize Tensor
|
||||
// Include setting attribute for tensor
|
||||
// and allocate cpu memory buffer
|
||||
|
@@ -35,39 +35,84 @@ std::string ClassifyResult::Str() {
|
||||
return out;
|
||||
}
|
||||
|
||||
// Request capacity for `size` elements in the mask buffer; the number of
// stored elements is left unchanged.
void Mask::Reserve(int size) {
  data.reserve(size);
}
|
||||
|
||||
// Change the number of elements stored in the mask buffer to `size`.
// `shape` is not touched here and must be kept in sync by the caller.
void Mask::Resize(int size) {
  data.resize(size);
}
|
||||
|
||||
// Drop the mask contents and its shape.
// Swapping with empty vectors (rather than calling clear()) also hands the
// old allocations to the temporaries, which free them on scope exit.
void Mask::Clear() {
  std::vector<int32_t> released_data;
  std::vector<int64_t> released_shape;
  data.swap(released_data);
  shape.swap(released_shape);
}
|
||||
|
||||
std::string Mask::Str() {
|
||||
std::string out = "Mask(";
|
||||
size_t ndim = shape.size();
|
||||
for (size_t i = 0; i < ndim; ++i) {
|
||||
if (i < ndim - 1) {
|
||||
out += std::to_string(shape[i]) + ",";
|
||||
} else {
|
||||
out += std::to_string(shape[i]);
|
||||
}
|
||||
}
|
||||
out += ")\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
// Copy constructor.
// Boxes, scores, label ids and the `contain_masks` flag are always copied;
// the masks themselves are copied only when `contain_masks` is set on the
// source result.
DetectionResult::DetectionResult(const DetectionResult& res) {
  boxes = res.boxes;
  scores = res.scores;
  label_ids = res.label_ids;
  contain_masks = res.contain_masks;
  if (!contain_masks) {
    return;
  }
  masks.assign(res.masks.begin(), res.masks.end());
}
|
||||
|
||||
// Reset the result to an empty state and mark it as carrying no masks.
// Each container is swapped with an empty one so its old allocation is
// released when the temporary goes out of scope.
void DetectionResult::Clear() {
  std::vector<std::array<float, 4>> released_boxes;
  std::vector<float> released_scores;
  std::vector<int32_t> released_label_ids;
  std::vector<Mask> released_masks;
  boxes.swap(released_boxes);
  scores.swap(released_scores);
  label_ids.swap(released_label_ids);
  masks.swap(released_masks);
  contain_masks = false;
}
|
||||
|
||||
// Request capacity for `size` detections in every per-detection container
// (boxes, scores, label ids and masks) without changing their sizes.
void DetectionResult::Reserve(int size) {
  const size_t capacity = static_cast<size_t>(size);
  boxes.reserve(capacity);
  scores.reserve(capacity);
  label_ids.reserve(capacity);
  masks.reserve(capacity);
}
|
||||
|
||||
// Resize every per-detection container (boxes, scores, label ids and
// masks) to hold exactly `size` entries.
void DetectionResult::Resize(int size) {
  const size_t count = static_cast<size_t>(size);
  boxes.resize(count);
  scores.resize(count);
  label_ids.resize(count);
  masks.resize(count);
}
|
||||
|
||||
std::string DetectionResult::Str() {
|
||||
std::string out;
|
||||
if (!contain_masks) {
|
||||
out = "DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]\n";
|
||||
} else {
|
||||
out =
|
||||
"DetectionResult: [xmin, ymin, xmax, ymax, score, label_id, "
|
||||
"mask_shape]\n";
|
||||
}
|
||||
for (size_t i = 0; i < boxes.size(); ++i) {
|
||||
out = out + std::to_string(boxes[i][0]) + "," +
|
||||
std::to_string(boxes[i][1]) + ", " + std::to_string(boxes[i][2]) +
|
||||
", " + std::to_string(boxes[i][3]) + ", " +
|
||||
std::to_string(scores[i]) + ", " + std::to_string(label_ids[i]) +
|
||||
"\n";
|
||||
std::to_string(scores[i]) + ", " + std::to_string(label_ids[i]);
|
||||
if (!contain_masks) {
|
||||
out += "\n";
|
||||
} else {
|
||||
out += ", " + masks[i].Str();
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
@@ -25,7 +25,8 @@ enum FASTDEPLOY_DECL ResultType {
|
||||
OCR,
|
||||
FACE_DETECTION,
|
||||
FACE_RECOGNITION,
|
||||
MATTING
|
||||
MATTING,
|
||||
MASK
|
||||
};
|
||||
|
||||
struct FASTDEPLOY_DECL BaseResult {
|
||||
@@ -41,11 +42,31 @@ struct FASTDEPLOY_DECL ClassifyResult : public BaseResult {
|
||||
std::string Str();
|
||||
};
|
||||
|
||||
// A single instance mask attached to a detection.
struct FASTDEPLOY_DECL Mask : public BaseResult {
  // Mask values stored as a flat buffer.
  std::vector<int32_t> data;
  // Extents of the mask.
  std::vector<int64_t> shape;  // (H,W) ...
  ResultType type = ResultType::MASK;

  // Capacity / size management of `data`.
  void Reserve(int size);
  void Resize(int size);

  // Empty both `data` and `shape`.
  void Clear();

  // Untyped access to the first element of `data`.
  void* Data() { return data.data(); }
  const void* Data() const { return data.data(); }

  // Debug string describing the mask shape.
  std::string Str();
};
|
||||
|
||||
struct FASTDEPLOY_DECL DetectionResult : public BaseResult {
|
||||
// box: xmin, ymin, xmax, ymax
|
||||
std::vector<std::array<float, 4>> boxes;
|
||||
std::vector<float> scores;
|
||||
std::vector<int32_t> label_ids;
|
||||
std::vector<Mask> masks;
|
||||
bool contain_masks = false;
|
||||
ResultType type = ResultType::DETECTION;
|
||||
|
||||
DetectionResult() {}
|
||||
|
120
csrc/fastdeploy/vision/detection/ppdet/mask_rcnn.cc
Normal file
120
csrc/fastdeploy/vision/detection/ppdet/mask_rcnn.cc
Normal file
@@ -0,0 +1,120 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/mask_rcnn.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// Construct a MaskRCNN model from an exported model, its parameters and
// the inference config file, then initialize it.
// The outcome of Initialize() is stored in `initialized`.
MaskRCNN::MaskRCNN(const std::string& model_file,
                   const std::string& params_file,
                   const std::string& config_file,
                   const RuntimeOption& custom_option,
                   const Frontend& model_format) {
  // Backend::PDINFER is the only backend listed for both CPU and GPU.
  valid_cpu_backends = {Backend::PDINFER};
  valid_gpu_backends = {Backend::PDINFER};

  config_file_ = config_file;

  // Start from the caller's options, then override the model description.
  runtime_option = custom_option;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  runtime_option.model_format = model_format;

  initialized = Initialize();
}
|
||||
|
||||
bool MaskRCNN::Postprocess(std::vector<FDTensor>& infer_result,
|
||||
DetectionResult* result) {
|
||||
// index 0: bbox_data [N, 6] float32
|
||||
// index 1: bbox_num [B=1] int32
|
||||
// index 2: mask_data [N, h, w] int32
|
||||
FDASSERT(infer_result[1].shape[0] == 1,
|
||||
"Only support batch = 1 in FastDeploy now.");
|
||||
FDASSERT(infer_result.size() == 3,
|
||||
"The infer_result must contains 3 otuput Tensors, but found %lu",
|
||||
infer_result.size());
|
||||
|
||||
FDTensor& box_tensor = infer_result[0];
|
||||
FDTensor& box_num_tensor = infer_result[1];
|
||||
FDTensor& mask_tensor = infer_result[2];
|
||||
|
||||
int box_num = 0;
|
||||
if (box_num_tensor.dtype == FDDataType::INT32) {
|
||||
box_num = *(static_cast<int32_t*>(box_num_tensor.Data()));
|
||||
} else if (box_num_tensor.dtype == FDDataType::INT64) {
|
||||
box_num = *(static_cast<int64_t*>(box_num_tensor.Data()));
|
||||
} else {
|
||||
FDASSERT(false,
|
||||
"The output box_num of PaddleDetection/MaskRCNN model should be "
|
||||
"type of int32/int64.");
|
||||
}
|
||||
if (box_num <= 0) {
|
||||
return true; // no object detected.
|
||||
}
|
||||
result->Resize(box_num);
|
||||
float* box_data = static_cast<float*>(box_tensor.Data());
|
||||
for (size_t i = 0; i < box_num; ++i) {
|
||||
result->label_ids[i] = static_cast<int>(box_data[i * 6]);
|
||||
result->scores[i] = box_data[i * 6 + 1];
|
||||
result->boxes[i] =
|
||||
std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
|
||||
box_data[i * 6 + 4], box_data[i * 6 + 5]};
|
||||
}
|
||||
result->contain_masks = true;
|
||||
// TODO(qiuyanjun): Cast int64/int8 to int32.
|
||||
FDASSERT(mask_tensor.dtype == FDDataType::INT32,
|
||||
"The dtype of mask Tensor must be int32 now!");
|
||||
// In PaddleDetection/MaskRCNN, the mask_h and mask_w
|
||||
// are already aligned with original input image. So,
|
||||
// we need to crop it from output mask according to
|
||||
// the detected bounding box.
|
||||
// +-----------------------+
|
||||
// | x1,y1 |
|
||||
// | +---------------+ |
|
||||
// | | | |
|
||||
// | | Crop | |
|
||||
// | | | |
|
||||
// | | | |
|
||||
// | +---------------+ |
|
||||
// | x2,y2 |
|
||||
// +-----------------------+
|
||||
int64_t out_mask_h = mask_tensor.shape[1];
|
||||
int64_t out_mask_w = mask_tensor.shape[2];
|
||||
int64_t out_mask_numel = out_mask_h * out_mask_w;
|
||||
int32_t* out_mask_data = static_cast<int32_t*>(mask_tensor.Data());
|
||||
for (size_t i = 0; i < box_num; ++i) {
|
||||
// crop instance mask according to box
|
||||
int64_t x1 = static_cast<int64_t>(result->boxes[i][0]);
|
||||
int64_t y1 = static_cast<int64_t>(result->boxes[i][1]);
|
||||
int64_t x2 = static_cast<int64_t>(result->boxes[i][2]);
|
||||
int64_t y2 = static_cast<int64_t>(result->boxes[i][3]);
|
||||
int64_t keep_mask_h = y2 - y1;
|
||||
int64_t keep_mask_w = x2 - x1;
|
||||
int64_t keep_mask_numel = keep_mask_h * keep_mask_w;
|
||||
result->masks[i].Resize(keep_mask_numel); // int32
|
||||
result->masks[i].shape = {keep_mask_h, keep_mask_w};
|
||||
int32_t* mask_start_ptr = out_mask_data + i * out_mask_numel;
|
||||
int32_t* keep_mask_ptr = static_cast<int32_t*>(result->masks[i].Data());
|
||||
for (size_t row = y1; row < y2; ++row) {
|
||||
size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
|
||||
int32_t* out_row_start_ptr = mask_start_ptr + row * out_mask_w + x1;
|
||||
int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w;
|
||||
std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
40
csrc/fastdeploy/vision/detection/ppdet/mask_rcnn.h
Normal file
40
csrc/fastdeploy/vision/detection/ppdet/mask_rcnn.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/rcnn.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// PaddleDetection MaskRCNN model. Derives from FasterRCNN and supplies its
// own post-processing that fills a DetectionResult.
class FASTDEPLOY_DECL MaskRCNN : public FasterRCNN {
 public:
  // Build the model from the exported model file, parameter file and
  // inference config file; `custom_option` and `model_format` are
  // forwarded to the runtime.
  MaskRCNN(const std::string& model_file, const std::string& params_file,
           const std::string& config_file,
           const RuntimeOption& custom_option = RuntimeOption(),
           const Frontend& model_format = Frontend::PADDLE);

  // Name reported for this model.
  virtual std::string ModelName() const {
    return "PaddleDetection/MaskRCNN";
  }

  // Convert the raw output tensors in `infer_result` into `result`.
  virtual bool Postprocess(std::vector<FDTensor>& infer_result,
                           DetectionResult* result);

 protected:
  // Default construction is restricted to this class and subclasses.
  MaskRCNN() {}
};
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/mask_rcnn.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/picodet.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyolo.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
|
@@ -15,54 +15,56 @@
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindPPDet(pybind11::module& m) {
|
||||
pybind11::class_<vision::detection::PPYOLOE, FastDeployModel>(m,
|
||||
"PPYOLOE")
|
||||
pybind11::class_<vision::detection::PPYOLOE, FastDeployModel>(m, "PPYOLOE")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def("predict", [](vision::detection::PPYOLOE& self, pybind11::array& data) {
|
||||
.def("predict",
|
||||
[](vision::detection::PPYOLOE& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::PPYOLO, FastDeployModel>(m,
|
||||
"PPYOLO")
|
||||
pybind11::class_<vision::detection::PPYOLO, FastDeployModel>(m, "PPYOLO")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def("predict", [](vision::detection::PPYOLO& self, pybind11::array& data) {
|
||||
.def("predict",
|
||||
[](vision::detection::PPYOLO& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::PPYOLOv2, FastDeployModel>(m,
|
||||
"PPYOLOv2")
|
||||
pybind11::class_<vision::detection::PPYOLOv2, FastDeployModel>(m, "PPYOLOv2")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def("predict", [](vision::detection::PPYOLOv2& self, pybind11::array& data) {
|
||||
.def("predict",
|
||||
[](vision::detection::PPYOLOv2& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::PicoDet, FastDeployModel>(m,
|
||||
"PicoDet")
|
||||
pybind11::class_<vision::detection::PicoDet, FastDeployModel>(m, "PicoDet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def("predict", [](vision::detection::PicoDet& self, pybind11::array& data) {
|
||||
.def("predict",
|
||||
[](vision::detection::PicoDet& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::PaddleYOLOX, FastDeployModel>(m, "PaddleYOLOX")
|
||||
pybind11::class_<vision::detection::PaddleYOLOX, FastDeployModel>(
|
||||
m, "PaddleYOLOX")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def("predict", [](vision::detection::PaddleYOLOX& self, pybind11::array& data) {
|
||||
.def("predict",
|
||||
[](vision::detection::PaddleYOLOX& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
@@ -81,11 +83,22 @@ void BindPPDet(pybind11::module& m) {
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::YOLOv3, FastDeployModel>(m,
|
||||
"YOLOv3")
|
||||
pybind11::class_<vision::detection::YOLOv3, FastDeployModel>(m, "YOLOv3")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def("predict", [](vision::detection::YOLOv3& self, pybind11::array& data) {
|
||||
.def("predict",
|
||||
[](vision::detection::YOLOv3& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::MaskRCNN, FastDeployModel>(m, "MaskRCNN")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def("predict",
|
||||
[](vision::detection::MaskRCNN& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
|
@@ -109,7 +109,7 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
|
||||
bool keep_ratio = op["keep_ratio"].as<bool>();
|
||||
auto target_size = op["target_size"].as<std::vector<int>>();
|
||||
int interp = op["interp"].as<int>();
|
||||
FDASSERT(target_size.size(),
|
||||
FDASSERT(target_size.size() == 2,
|
||||
"Require size of target_size be 2, but now it's %lu.",
|
||||
target_size.size());
|
||||
if (!keep_ratio) {
|
||||
|
@@ -14,11 +14,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -87,7 +87,8 @@ void ArgmaxScoreMap(T infer_result_buffer, SegmentationResult* result,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T> void NCHW2NHWC(FDTensor& infer_result) {
|
||||
template <typename T>
|
||||
void NCHW2NHWC(FDTensor& infer_result) {
|
||||
T* infer_result_buffer = reinterpret_cast<T*>(infer_result.MutableData());
|
||||
int num = infer_result.shape[0];
|
||||
int channel = infer_result.shape[1];
|
||||
@@ -124,8 +125,8 @@ void SortDetectionResult(DetectionResult* output);
|
||||
void SortDetectionResult(FaceDetectionResult* result);
|
||||
|
||||
// L2 Norm / cosine similarity (for face recognition, ...)
|
||||
FASTDEPLOY_DECL std::vector<float>
|
||||
L2Normalize(const std::vector<float>& values);
|
||||
FASTDEPLOY_DECL std::vector<float> L2Normalize(
|
||||
const std::vector<float>& values);
|
||||
|
||||
FASTDEPLOY_DECL float CosineSimilarity(const std::vector<float>& a,
|
||||
const std::vector<float>& b,
|
||||
|
@@ -28,6 +28,13 @@ void BindVisualize(pybind11::module& m);
|
||||
#endif
|
||||
|
||||
void BindVision(pybind11::module& m) {
|
||||
pybind11::class_<vision::Mask>(m, "Mask")
|
||||
.def(pybind11::init())
|
||||
.def_readwrite("data", &vision::Mask::data)
|
||||
.def_readwrite("shape", &vision::Mask::shape)
|
||||
.def("__repr__", &vision::Mask::Str)
|
||||
.def("__str__", &vision::Mask::Str);
|
||||
|
||||
pybind11::class_<vision::ClassifyResult>(m, "ClassifyResult")
|
||||
.def(pybind11::init())
|
||||
.def_readwrite("label_ids", &vision::ClassifyResult::label_ids)
|
||||
@@ -40,6 +47,8 @@ void BindVision(pybind11::module& m) {
|
||||
.def_readwrite("boxes", &vision::DetectionResult::boxes)
|
||||
.def_readwrite("scores", &vision::DetectionResult::scores)
|
||||
.def_readwrite("label_ids", &vision::DetectionResult::label_ids)
|
||||
.def_readwrite("masks", &vision::DetectionResult::masks)
|
||||
.def_readwrite("contain_masks", &vision::DetectionResult::contain_masks)
|
||||
.def("__repr__", &vision::DetectionResult::Str)
|
||||
.def("__str__", &vision::DetectionResult::Str);
|
||||
|
||||
@@ -52,6 +61,7 @@ void BindVision(pybind11::module& m) {
|
||||
.def_readwrite("cls_labels", &vision::OCRResult::cls_labels)
|
||||
.def("__repr__", &vision::OCRResult::Str)
|
||||
.def("__str__", &vision::OCRResult::Str);
|
||||
|
||||
pybind11::class_<vision::FaceDetectionResult>(m, "FaceDetectionResult")
|
||||
.def(pybind11::init())
|
||||
.def_readwrite("boxes", &vision::FaceDetectionResult::boxes)
|
||||
|
@@ -27,6 +27,10 @@ cv::Mat Visualize::VisDetection(const cv::Mat& im,
|
||||
const DetectionResult& result,
|
||||
float score_threshold, int line_size,
|
||||
float font_size) {
|
||||
if (result.contain_masks) {
|
||||
FDASSERT(result.boxes.size() == result.masks.size(),
|
||||
"The size of masks must be equal the size of boxes!");
|
||||
}
|
||||
auto color_map = GetColorMap();
|
||||
int h = im.rows;
|
||||
int w = im.cols;
|
||||
@@ -35,9 +39,12 @@ cv::Mat Visualize::VisDetection(const cv::Mat& im,
|
||||
if (result.scores[i] < score_threshold) {
|
||||
continue;
|
||||
}
|
||||
cv::Rect rect(result.boxes[i][0], result.boxes[i][1],
|
||||
result.boxes[i][2] - result.boxes[i][0],
|
||||
result.boxes[i][3] - result.boxes[i][1]);
|
||||
int x1 = static_cast<int>(result.boxes[i][0]);
|
||||
int y1 = static_cast<int>(result.boxes[i][1]);
|
||||
int x2 = static_cast<int>(result.boxes[i][2]);
|
||||
int y2 = static_cast<int>(result.boxes[i][3]);
|
||||
int box_h = y2 - y1;
|
||||
int box_w = x2 - x1;
|
||||
int c0 = color_map[3 * result.label_ids[i] + 0];
|
||||
int c1 = color_map[3 * result.label_ids[i] + 1];
|
||||
int c2 = color_map[3 * result.label_ids[i] + 2];
|
||||
@@ -51,14 +58,46 @@ cv::Mat Visualize::VisDetection(const cv::Mat& im,
|
||||
int font = cv::FONT_HERSHEY_SIMPLEX;
|
||||
cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
|
||||
cv::Point origin;
|
||||
origin.x = rect.x;
|
||||
origin.y = rect.y;
|
||||
cv::Rect text_background =
|
||||
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
|
||||
text_size.width, text_size.height);
|
||||
origin.x = x1;
|
||||
origin.y = y1;
|
||||
cv::Rect rect(x1, y1, box_w, box_h);
|
||||
cv::rectangle(vis_im, rect, rect_color, line_size);
|
||||
cv::putText(vis_im, text, origin, font, font_size,
|
||||
cv::Scalar(255, 255, 255), 1);
|
||||
if (result.contain_masks) {
|
||||
int mask_h = static_cast<int>(result.masks[i].shape[0]);
|
||||
int mask_w = static_cast<int>(result.masks[i].shape[1]);
|
||||
// non-const pointer for cv:Mat constructor
|
||||
int32_t* mask_raw_data = const_cast<int32_t*>(
|
||||
static_cast<const int32_t*>(result.masks[i].Data()));
|
||||
// only reference to mask data (zero copy)
|
||||
cv::Mat mask(mask_h, mask_w, CV_32SC1, mask_raw_data);
|
||||
if ((mask_h != box_h) || (mask_w != box_w)) {
|
||||
cv::resize(mask, mask, cv::Size(box_w, box_h));
|
||||
}
|
||||
// use a bright color for instance mask
|
||||
int mc0 = 255 - c0 >= 127 ? 255 - c0 : 127;
|
||||
int mc1 = 255 - c1 >= 127 ? 255 - c1 : 127;
|
||||
int mc2 = 255 - c2 >= 127 ? 255 - c2 : 127;
|
||||
int32_t* mask_data = reinterpret_cast<int32_t*>(mask.data);
|
||||
// inplace blending (zero copy)
|
||||
uchar* vis_im_data = static_cast<uchar*>(vis_im.data);
|
||||
for (size_t i = y1; i < y2; ++i) {
|
||||
for (size_t j = x1; j < x2; ++j) {
|
||||
if (mask_data[(i - y1) * mask_w + (j - x1)] != 0) {
|
||||
vis_im_data[i * w * 3 + j * 3 + 0] = cv::saturate_cast<uchar>(
|
||||
static_cast<float>(mc0) * 0.5f +
|
||||
static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 0]) * 0.5f);
|
||||
vis_im_data[i * w * 3 + j * 3 + 1] = cv::saturate_cast<uchar>(
|
||||
static_cast<float>(mc1) * 0.5f +
|
||||
static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 1]) * 0.5f);
|
||||
vis_im_data[i * w * 3 + j * 3 + 2] = cv::saturate_cast<uchar>(
|
||||
static_cast<float>(mc2) * 0.5f +
|
||||
static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 2]) * 0.5f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return vis_im;
|
||||
}
|
||||
|
@@ -45,7 +45,6 @@ class FASTDEPLOY_DECL Visualize {
|
||||
int background_label,
|
||||
const SegmentationResult& result);
|
||||
static cv::Mat VisOcr(const cv::Mat& srcimg, const OCRResult& ocr_result);
|
||||
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
|
@@ -4,13 +4,17 @@ DetectionResult代码定义在`csrcs/fastdeploy/vision/common/result.h`中,用
|
||||
|
||||
## C++ 定义
|
||||
|
||||
`fastdeploy::vision::DetectionResult`
|
||||
```c++
|
||||
fastdeploy::vision::DetectionResult
|
||||
```
|
||||
|
||||
```c++
|
||||
struct DetectionResult {
|
||||
std::vector<std::array<float, 4>> boxes;
|
||||
std::vector<float> scores;
|
||||
std::vector<int32_t> label_ids;
|
||||
std::vector<Mask> masks;
|
||||
bool contain_masks = false;
|
||||
void Clear();
|
||||
std::string Str();
|
||||
};
|
||||
@@ -19,13 +23,42 @@ struct DetectionResult {
|
||||
- **boxes**: 成员变量,表示单张图片检测出来的所有目标框坐标,`boxes.size()`表示框的个数,每个框以4个float数值依次表示xmin, ymin, xmax, ymax, 即左上角和右下角坐标
|
||||
- **scores**: 成员变量,表示单张图片检测出来的所有目标置信度,其元素个数与`boxes.size()`一致
|
||||
- **label_ids**: 成员变量,表示单张图片检测出来的所有目标类别,其元素个数与`boxes.size()`一致
|
||||
- **masks**: 成员变量,表示单张图片检测出来的所有实例mask,其元素个数及shape大小与`boxes`一致
|
||||
- **contain_masks**: 成员变量,表示检测结果中是否包含实例mask,实例分割模型的结果此项一般为true.
|
||||
- **Clear()**: 成员函数,用于清除结构体中存储的结果
|
||||
- **Str()**: 成员函数,将结构体中的信息以字符串形式输出(用于Debug)
|
||||
|
||||
```c++
|
||||
fastdeploy::vision::Mask
|
||||
```
|
||||
```c++
|
||||
struct Mask {
|
||||
std::vector<int32_t> data;
|
||||
std::vector<int64_t> shape; // (H,W) ...
|
||||
|
||||
void Clear();
|
||||
std::string Str();
|
||||
};
|
||||
```
|
||||
- **data**: 成员变量,表示检测到的一个mask
|
||||
- **shape**: 成员变量,表示mask的shape,如 (h,w)
|
||||
- **Clear()**: 成员函数,用于清除结构体中存储的结果
|
||||
- **Str()**: 成员函数,将结构体中的信息以字符串形式输出(用于Debug)
|
||||
|
||||
## Python 定义
|
||||
|
||||
`fastdeploy.vision.DetectionResult`
|
||||
```python
|
||||
fastdeploy.vision.DetectionResult
|
||||
```
|
||||
|
||||
- **boxes**(list of list(float)): 成员变量,表示单张图片检测出来的所有目标框坐标。boxes是一个list,其每个元素为一个长度为4的list, 表示为一个框,每个框以4个float数值依次表示xmin, ymin, xmax, ymax, 即左上角和右下角坐标
|
||||
- **scores**(list of float): 成员变量,表示单张图片检测出来的所有目标置信度
|
||||
- **label_ids**(list of int): 成员变量,表示单张图片检测出来的所有目标类别
|
||||
- **masks**: 成员变量,表示单张图片检测出来的所有实例mask,其元素个数及shape大小与`boxes`一致
|
||||
- **contain_masks**: 成员变量,表示检测结果中是否包含实例mask,实例分割模型的结果此项一般为True.
|
||||
|
||||
```python
|
||||
fastdeploy.vision.Mask
|
||||
```
|
||||
- **data**: 成员变量,表示检测到的一个mask
|
||||
- **shape**: 成员变量,表示mask的shape,如 (h,w)
|
||||
|
@@ -14,6 +14,7 @@
|
||||
- [YOLOv3系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolov3)
|
||||
- [YOLOX系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolox)
|
||||
- [FasterRCNN系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/faster_rcnn)
|
||||
- [MaskRCNN系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/mask_rcnn)
|
||||
|
||||
## 导出部署模型
|
||||
|
||||
@@ -32,13 +33,14 @@
|
||||
|
||||
| 模型 | 参数大小 | 精度 | 备注 |
|
||||
|:---------------------------------------------------------------- |:----- |:----- | :------ |
|
||||
| [picodet_l_320_coco_lcnet](https://bj.bcebos.com/paddlehub/fastdeploy/picodet_l_320_coco_lcnet.tgz) |23MB | 42.6% |
|
||||
| [ppyoloe_crn_l_300e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz) |200MB | 51.4% |
|
||||
| [ppyolo_r50vd_dcn_1x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyolo_r50vd_dcn_1x_coco.tgz) | 180MB | 44.8% | 暂不支持TensorRT |
|
||||
| [ppyolov2_r101vd_dcn_365e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyolov2_r101vd_dcn_365e_coco.tgz) | 282MB | 49.7% | 暂不支持TensorRT |
|
||||
| [yolov3_darknet53_270e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/yolov3_darknet53_270e_coco.tgz) |237MB | 39.1% | |
|
||||
| [yolox_s_300e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/yolox_s_300e_coco.tgz) | 35MB | 40.4% | |
|
||||
| [faster_rcnn_r50_vd_fpn_2x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz) | 160MB | 40.8%| 暂不支持TensorRT |
|
||||
| [picodet_l_320_coco_lcnet](https://bj.bcebos.com/paddlehub/fastdeploy/picodet_l_320_coco_lcnet.tgz) |23MB | Box AP 42.6% |
|
||||
| [ppyoloe_crn_l_300e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz) |200MB | Box AP 51.4% |
|
||||
| [ppyolo_r50vd_dcn_1x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyolo_r50vd_dcn_1x_coco.tgz) | 180MB | Box AP 44.8% | 暂不支持TensorRT |
|
||||
| [ppyolov2_r101vd_dcn_365e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyolov2_r101vd_dcn_365e_coco.tgz) | 282MB | Box AP 49.7% | 暂不支持TensorRT |
|
||||
| [yolov3_darknet53_270e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/yolov3_darknet53_270e_coco.tgz) |237MB | Box AP 39.1% | |
|
||||
| [yolox_s_300e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/yolox_s_300e_coco.tgz) | 35MB | Box AP 40.4% | |
|
||||
| [faster_rcnn_r50_vd_fpn_2x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz) | 160MB | Box AP 40.8%| 暂不支持TensorRT |
|
||||
| [mask_rcnn_r50_1x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/mask_rcnn_r50_1x_coco.tgz) | 128MB | Box AP 37.4%, Mask AP 32.8%| 暂不支持TensorRT、ORT |
|
||||
|
||||
|
||||
## 详细部署文档
|
||||
|
100
examples/vision/detection/paddledetection/cpp/infer_mask_rcnn.cc
Normal file
100
examples/vision/detection/paddledetection/cpp/infer_mask_rcnn.cc
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision.h"
|
||||
|
||||
#ifdef WIN32
|
||||
const char sep = '\\';
|
||||
#else
|
||||
const char sep = '/';
|
||||
#endif
|
||||
|
||||
void CpuInfer(const std::string& model_dir, const std::string& image_file) {
|
||||
auto model_file = model_dir + sep + "model.pdmodel";
|
||||
auto params_file = model_dir + sep + "model.pdiparams";
|
||||
auto config_file = model_dir + sep + "infer_cfg.yml";
|
||||
auto model = fastdeploy::vision::detection::MaskRCNN(model_file, params_file,
|
||||
config_file);
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Failed to initialize." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::DetectionResult res;
|
||||
if (!model.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
std::cout << res.Str() << std::endl;
|
||||
auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5);
|
||||
cv::imwrite("vis_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
void GpuInfer(const std::string& model_dir, const std::string& image_file) {
|
||||
auto model_file = model_dir + sep + "model.pdmodel";
|
||||
auto params_file = model_dir + sep + "model.pdiparams";
|
||||
auto config_file = model_dir + sep + "infer_cfg.yml";
|
||||
|
||||
auto option = fastdeploy::RuntimeOption();
|
||||
option.UseGpu();
|
||||
auto model = fastdeploy::vision::detection::MaskRCNN(model_file, params_file,
|
||||
config_file, option);
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Failed to initialize." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::DetectionResult res;
|
||||
if (!model.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
std::cout << res.Str() << std::endl;
|
||||
auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5);
|
||||
cv::imwrite("vis_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc < 4) {
|
||||
std::cout
|
||||
<< "Usage: infer_demo path/to/model_dir path/to/image run_option, "
|
||||
"e.g ./infer_model ./mask_rcnn_r50_1x_coco/ ./test.jpeg 0"
|
||||
<< std::endl;
|
||||
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
|
||||
"with gpu."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (std::atoi(argv[3]) == 0) {
|
||||
CpuInfer(argv[1], argv[2]);
|
||||
} else if (std::atoi(argv[3]) == 1) {
|
||||
GpuInfer(argv[1], argv[2]);
|
||||
} else if (std::atoi(argv[3]) == 2) {
|
||||
std::cout
|
||||
<< "Backend::TRT has not been supported yet, will skip this inference."
|
||||
<< std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
@@ -39,6 +39,7 @@ fastdeploy.vision.detection.PaddleYOLOX(model_file, params_file, config_file, ru
|
||||
fastdeploy.vision.detection.YOLOv3(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE)
|
||||
fastdeploy.vision.detection.PPYOLO(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE)
|
||||
fastdeploy.vision.detection.FasterRCNN(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE)
|
||||
fastdeploy.vision.detection.MaskRCNN(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE)
|
||||
```
|
||||
|
||||
PaddleDetection模型加载和初始化,其中model_file, params_file为导出的Paddle部署模型格式, config_file为PaddleDetection同时导出的部署配置yaml文件
|
||||
|
@@ -0,0 +1,69 @@
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
|
||||
|
||||
def parse_arguments():
    """Parse the command-line arguments of the inference script.

    Returns:
        argparse.Namespace with the attributes:
            model_dir: required, path of the PaddleDetection model directory.
            image: required, path of the test image file.
            device: string, defaults to 'cpu'.
            use_trt: value parsed with ast.literal_eval, defaults to False.
    """
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        required=True,
        help="Path of PaddleDetection model directory")
    parser.add_argument(
        "--image", required=True, help="Path of test image file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    # ast.literal_eval turns the literal text "True"/"False" into a bool;
    # plain `type=bool` would treat any non-empty string as True.
    parser.add_argument(
        "--use_trt",
        type=ast.literal_eval,
        default=False,
        help="Whether to use tensorrt.")
    return parser.parse_args()
|
||||
|
||||
|
||||
def build_option(args):
    """Create the runtime option used to load the model.

    Args:
        args: parsed arguments; `args.device` and `args.use_trt` are read.

    Returns:
        A freshly created fd.RuntimeOption. Neither the GPU nor the TensorRT
        request changes it; each one only prints a notice.
    """
    runtime_opt = fd.RuntimeOption()

    gpu_requested = args.device.lower() == "gpu"
    if gpu_requested:
        # option.use_gpu() is deliberately not called here.
        print("GPU inference with Backend::Paddle in python has not been "
              "supported yet. "
              "\nWill ignore this option.")

    if args.use_trt:
        # TODO(qiuyanjun): may remove TRT option
        # Backend::TRT has not been supported yet.
        print("Backend::TRT has not been supported yet, will ignore this "
              "option."
              "\nPaddleDetection/MaskRCNN has only support Backend::Paddle "
              "now.")

    return runtime_opt
|
||||
|
||||
|
||||
args = parse_arguments()

model_file = os.path.join(args.model_dir, "model.pdmodel")
params_file = os.path.join(args.model_dir, "model.pdiparams")
config_file = os.path.join(args.model_dir, "infer_cfg.yml")

# Configure the runtime and load the model.
runtime_option = build_option(args)
model = fd.vision.detection.MaskRCNN(
    model_file, params_file, config_file, runtime_option=runtime_option)

# Run detection on the input image.
im = cv2.imread(args.image)
# cv2.imread returns None (it does not raise) when the file is missing or
# cannot be decoded; fail with a clear message instead of an AttributeError
# on im.copy() below.
if im is None:
    raise FileNotFoundError("Failed to read image: {}".format(args.image))
# A copy is passed to predict so the original image is the one drawn on.
result = model.predict(im.copy())
print(result)

# Visualize the detection result.
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")
print(runtime_option)
|
@@ -23,4 +23,4 @@ from .contrib.yolov5lite import YOLOv5Lite
|
||||
from .contrib.yolov6 import YOLOv6
|
||||
from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT
|
||||
from .contrib.yolov7end2end_ort import YOLOv7End2EndORT
|
||||
from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3
|
||||
from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3, MaskRCNN
|
||||
|
@@ -28,8 +28,8 @@ class PPYOLOE(FastDeployModel):
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "PPYOLOE model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.detection.PPYOLOE(model_file, params_file,
|
||||
config_file, self._runtime_option,
|
||||
self._model = C.vision.detection.PPYOLOE(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "PPYOLOE model initialize failed."
|
||||
|
||||
@@ -48,8 +48,8 @@ class PPYOLO(PPYOLOE):
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "PPYOLO model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.detection.PPYOLO(model_file, params_file,
|
||||
config_file, self._runtime_option,
|
||||
self._model = C.vision.detection.PPYOLO(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "PPYOLO model initialize failed."
|
||||
|
||||
@@ -64,8 +64,8 @@ class PPYOLOv2(PPYOLOE):
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "PPYOLOv2 model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.detection.PPYOLOv2(model_file, params_file,
|
||||
config_file, self._runtime_option,
|
||||
self._model = C.vision.detection.PPYOLOv2(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "PPYOLOv2 model initialize failed."
|
||||
|
||||
@@ -80,8 +80,8 @@ class PaddleYOLOX(PPYOLOE):
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "PaddleYOLOX model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.detection.PaddleYOLOX(model_file, params_file,
|
||||
config_file, self._runtime_option,
|
||||
self._model = C.vision.detection.PaddleYOLOX(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "PaddleYOLOX model initialize failed."
|
||||
|
||||
@@ -96,8 +96,8 @@ class PicoDet(PPYOLOE):
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "PicoDet model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.detection.PicoDet(model_file, params_file,
|
||||
config_file, self._runtime_option,
|
||||
self._model = C.vision.detection.PicoDet(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "PicoDet model initialize failed."
|
||||
|
||||
@@ -128,7 +128,27 @@ class YOLOv3(PPYOLOE):
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "YOLOv3 model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.detection.YOLOv3(model_file, params_file,
|
||||
config_file, self._runtime_option,
|
||||
self._model = C.vision.detection.YOLOv3(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "YOLOv3 model initialize failed."
|
||||
|
||||
|
||||
class MaskRCNN(FastDeployModel):
    """Python wrapper around C.vision.detection.MaskRCNN."""

    def __init__(self,
                 model_file,
                 params_file,
                 config_file,
                 runtime_option=None,
                 model_format=Frontend.PADDLE):
        """Load a MaskRCNN model.

        Args:
            model_file: model file passed through to the C model.
            params_file: parameters file passed through to the C model.
            config_file: configuration file passed through to the C model.
            runtime_option: forwarded to FastDeployModel.__init__.
            model_format: must be Frontend.PADDLE.

        Raises:
            AssertionError: if `model_format` is not Frontend.PADDLE or the
                model reports that it is not initialized.
        """
        super(MaskRCNN, self).__init__(runtime_option)

        is_paddle_format = model_format == Frontend.PADDLE
        assert is_paddle_format, "MaskRCNN model only support model format of Frontend.Paddle now."
        # self._runtime_option is presumably set by the base-class
        # constructor above -- TODO confirm.
        self._model = C.vision.detection.MaskRCNN(model_file, params_file,
                                                  config_file,
                                                  self._runtime_option,
                                                  model_format)
        assert self.initialized, "MaskRCNN model initialize failed."

    def predict(self, input_image):
        """Run the wrapped model on `input_image` and return its result.

        Raises:
            AssertionError: if `input_image` is None.
        """
        assert input_image is not None, "The input image data is None."
        detection = self._model.predict(input_image)
        return detection
|
||||
|
Reference in New Issue
Block a user