mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
Fix classification and move segmentation
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
#include "fastdeploy/core/config.h"
|
#include "fastdeploy/core/config.h"
|
||||||
#ifdef ENABLE_VISION
|
#ifdef ENABLE_VISION
|
||||||
|
#include "fastdeploy/vision/classification/ppcls/model.h"
|
||||||
#include "fastdeploy/vision/detection/contrib/nanodet_plus.h"
|
#include "fastdeploy/vision/detection/contrib/nanodet_plus.h"
|
||||||
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
|
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
|
||||||
#include "fastdeploy/vision/detection/contrib/yolor.h"
|
#include "fastdeploy/vision/detection/contrib/yolor.h"
|
||||||
@@ -23,6 +24,7 @@
|
|||||||
#include "fastdeploy/vision/detection/contrib/yolov6.h"
|
#include "fastdeploy/vision/detection/contrib/yolov6.h"
|
||||||
#include "fastdeploy/vision/detection/contrib/yolov7.h"
|
#include "fastdeploy/vision/detection/contrib/yolov7.h"
|
||||||
#include "fastdeploy/vision/detection/contrib/yolox.h"
|
#include "fastdeploy/vision/detection/contrib/yolox.h"
|
||||||
|
#include "fastdeploy/vision/detection/ppdet/model.h"
|
||||||
#include "fastdeploy/vision/facedet/contrib/retinaface.h"
|
#include "fastdeploy/vision/facedet/contrib/retinaface.h"
|
||||||
#include "fastdeploy/vision/facedet/contrib/scrfd.h"
|
#include "fastdeploy/vision/facedet/contrib/scrfd.h"
|
||||||
#include "fastdeploy/vision/facedet/contrib/ultraface.h"
|
#include "fastdeploy/vision/facedet/contrib/ultraface.h"
|
||||||
@@ -33,9 +35,7 @@
|
|||||||
#include "fastdeploy/vision/faceid/contrib/partial_fc.h"
|
#include "fastdeploy/vision/faceid/contrib/partial_fc.h"
|
||||||
#include "fastdeploy/vision/faceid/contrib/vpl.h"
|
#include "fastdeploy/vision/faceid/contrib/vpl.h"
|
||||||
#include "fastdeploy/vision/matting/contrib/modnet.h"
|
#include "fastdeploy/vision/matting/contrib/modnet.h"
|
||||||
#include "fastdeploy/vision/classification/ppcls/model.h"
|
#include "fastdeploy/vision/segmentation/ppseg/model.h"
|
||||||
#include "fastdeploy/vision/detection/ppdet/model.h"
|
|
||||||
#include "fastdeploy/vision/ppseg/model.h"
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "fastdeploy/vision/visualize/visualize.h"
|
#include "fastdeploy/vision/visualize/visualize.h"
|
||||||
|
@@ -0,0 +1,26 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
void BindPaddleClas(pybind11::module& m);
|
||||||
|
|
||||||
|
void BindClassification(pybind11::module& m) {
|
||||||
|
auto classification_module =
|
||||||
|
m.def_submodule("classification", "Image classification models.");
|
||||||
|
BindPaddleClas(classification_module);
|
||||||
|
}
|
||||||
|
} // namespace fastdeploy
|
@@ -14,16 +14,17 @@
|
|||||||
#include "fastdeploy/pybind/main.h"
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
void BindPPCls(pybind11::module& m) {
|
void BindPaddleClas(pybind11::module& m) {
|
||||||
pybind11::class_<vision::classification::PaddleClasModel, FastDeployModel>(m, "PaddleClasModel")
|
pybind11::class_<vision::classification::PaddleClasModel, FastDeployModel>(
|
||||||
|
m, "PaddleClasModel")
|
||||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||||
Frontend>())
|
Frontend>())
|
||||||
.def("predict",
|
.def("predict", [](vision::classification::PaddleClasModel& self,
|
||||||
[](vision::classification::PaddleClasModel& self, pybind11::array& data, int topk = 1) {
|
pybind11::array& data, int topk = 1) {
|
||||||
auto mat = PyArrayToCvMat(data);
|
auto mat = PyArrayToCvMat(data);
|
||||||
vision::ClassifyResult res;
|
vision::ClassifyResult res;
|
||||||
self.Predict(&mat, &res, topk);
|
self.Predict(&mat, &res, topk);
|
||||||
return res;
|
return res;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -1,15 +1,17 @@
|
|||||||
#include "fastdeploy/vision/ppseg/model.h"
|
#include "fastdeploy/vision/segmentation/ppseg/model.h"
|
||||||
#include "fastdeploy/vision.h"
|
#include "fastdeploy/vision.h"
|
||||||
#include "fastdeploy/vision/utils/utils.h"
|
#include "fastdeploy/vision/utils/utils.h"
|
||||||
#include "yaml-cpp/yaml.h"
|
#include "yaml-cpp/yaml.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace ppseg {
|
namespace segmentation {
|
||||||
|
|
||||||
Model::Model(const std::string& model_file, const std::string& params_file,
|
PaddleSegModel::PaddleSegModel(const std::string& model_file,
|
||||||
const std::string& config_file, const RuntimeOption& custom_option,
|
const std::string& params_file,
|
||||||
const Frontend& model_format) {
|
const std::string& config_file,
|
||||||
|
const RuntimeOption& custom_option,
|
||||||
|
const Frontend& model_format) {
|
||||||
config_file_ = config_file;
|
config_file_ = config_file;
|
||||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
||||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT};
|
valid_gpu_backends = {Backend::PDINFER, Backend::ORT};
|
||||||
@@ -20,7 +22,7 @@ Model::Model(const std::string& model_file, const std::string& params_file,
|
|||||||
initialized = Initialize();
|
initialized = Initialize();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Model::Initialize() {
|
bool PaddleSegModel::Initialize() {
|
||||||
if (!BuildPreprocessPipelineFromConfig()) {
|
if (!BuildPreprocessPipelineFromConfig()) {
|
||||||
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
@@ -33,7 +35,7 @@ bool Model::Initialize() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Model::BuildPreprocessPipelineFromConfig() {
|
bool PaddleSegModel::BuildPreprocessPipelineFromConfig() {
|
||||||
processors_.clear();
|
processors_.clear();
|
||||||
YAML::Node cfg;
|
YAML::Node cfg;
|
||||||
processors_.push_back(std::make_shared<BGR2RGB>());
|
processors_.push_back(std::make_shared<BGR2RGB>());
|
||||||
@@ -75,8 +77,9 @@ bool Model::BuildPreprocessPipelineFromConfig() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Model::Preprocess(Mat* mat, FDTensor* output,
|
bool PaddleSegModel::Preprocess(
|
||||||
std::map<std::string, std::array<int, 2>>* im_info) {
|
Mat* mat, FDTensor* output,
|
||||||
|
std::map<std::string, std::array<int, 2>>* im_info) {
|
||||||
for (size_t i = 0; i < processors_.size(); ++i) {
|
for (size_t i = 0; i < processors_.size(); ++i) {
|
||||||
if (processors_[i]->Name().compare("Resize") == 0) {
|
if (processors_[i]->Name().compare("Resize") == 0) {
|
||||||
auto processor = dynamic_cast<Resize*>(processors_[i].get());
|
auto processor = dynamic_cast<Resize*>(processors_[i].get());
|
||||||
@@ -107,8 +110,9 @@ bool Model::Preprocess(Mat* mat, FDTensor* output,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Model::Postprocess(FDTensor& infer_result, SegmentationResult* result,
|
bool PaddleSegModel::Postprocess(
|
||||||
std::map<std::string, std::array<int, 2>>* im_info) {
|
FDTensor& infer_result, SegmentationResult* result,
|
||||||
|
std::map<std::string, std::array<int, 2>>* im_info) {
|
||||||
// PaddleSeg has three types of inference output:
|
// PaddleSeg has three types of inference output:
|
||||||
// 1. output with argmax and without softmax. 3-D matrix CHW, Channel
|
// 1. output with argmax and without softmax. 3-D matrix CHW, Channel
|
||||||
// always 1, the element in matrix is classified label_id INT64 Type.
|
// always 1, the element in matrix is classified label_id INT64 Type.
|
||||||
@@ -196,7 +200,7 @@ bool Model::Postprocess(FDTensor& infer_result, SegmentationResult* result,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Model::Predict(cv::Mat* im, SegmentationResult* result) {
|
bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
|
||||||
Mat mat(*im);
|
Mat mat(*im);
|
||||||
std::vector<FDTensor> processed_data(1);
|
std::vector<FDTensor> processed_data(1);
|
||||||
|
|
||||||
@@ -227,6 +231,6 @@ bool Model::Predict(cv::Mat* im, SegmentationResult* result) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ppseg
|
} // namespace segmentation
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
@@ -5,16 +5,16 @@
|
|||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace ppseg {
|
namespace segmentation {
|
||||||
|
|
||||||
class FASTDEPLOY_DECL Model : public FastDeployModel {
|
class FASTDEPLOY_DECL PaddleSegModel : public FastDeployModel {
|
||||||
public:
|
public:
|
||||||
Model(const std::string& model_file, const std::string& params_file,
|
PaddleSegModel(const std::string& model_file, const std::string& params_file,
|
||||||
const std::string& config_file,
|
const std::string& config_file,
|
||||||
const RuntimeOption& custom_option = RuntimeOption(),
|
const RuntimeOption& custom_option = RuntimeOption(),
|
||||||
const Frontend& model_format = Frontend::PADDLE);
|
const Frontend& model_format = Frontend::PADDLE);
|
||||||
|
|
||||||
std::string ModelName() const { return "ppseg"; }
|
std::string ModelName() const { return "PaddleSeg"; }
|
||||||
|
|
||||||
virtual bool Predict(cv::Mat* im, SegmentationResult* result);
|
virtual bool Predict(cv::Mat* im, SegmentationResult* result);
|
||||||
|
|
||||||
@@ -38,6 +38,6 @@ class FASTDEPLOY_DECL Model : public FastDeployModel {
|
|||||||
std::vector<std::shared_ptr<Processor>> processors_;
|
std::vector<std::shared_ptr<Processor>> processors_;
|
||||||
std::string config_file_;
|
std::string config_file_;
|
||||||
};
|
};
|
||||||
} // namespace ppseg
|
} // namespace segmentation
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
@@ -15,21 +15,22 @@
|
|||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
void BindPPSeg(pybind11::module& m) {
|
void BindPPSeg(pybind11::module& m) {
|
||||||
auto ppseg_module =
|
pybind11::class_<vision::segmentation::PaddleSegModel, FastDeployModel>(
|
||||||
m.def_submodule("ppseg", "Module to deploy PaddleSegmentation.");
|
m, "PaddleSegModel")
|
||||||
pybind11::class_<vision::ppseg::Model, FastDeployModel>(ppseg_module, "Model")
|
|
||||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||||
Frontend>())
|
Frontend>())
|
||||||
.def("predict",
|
.def("predict",
|
||||||
[](vision::ppseg::Model& self, pybind11::array& data) {
|
[](vision::segmentation::PaddleSegModel& self,
|
||||||
|
pybind11::array& data) {
|
||||||
auto mat = PyArrayToCvMat(data);
|
auto mat = PyArrayToCvMat(data);
|
||||||
vision::SegmentationResult* res = new vision::SegmentationResult();
|
vision::SegmentationResult* res = new vision::SegmentationResult();
|
||||||
// self.Predict(&mat, &res);
|
// self.Predict(&mat, &res);
|
||||||
self.Predict(&mat, res);
|
self.Predict(&mat, res);
|
||||||
return res;
|
return res;
|
||||||
})
|
})
|
||||||
.def_readwrite("with_softmax", &vision::ppseg::Model::with_softmax)
|
.def_readwrite("with_softmax",
|
||||||
|
&vision::segmentation::PaddleSegModel::with_softmax)
|
||||||
.def_readwrite("is_vertical_screen",
|
.def_readwrite("is_vertical_screen",
|
||||||
&vision::ppseg::Model::is_vertical_screen);
|
&vision::segmentation::PaddleSegModel::is_vertical_screen);
|
||||||
}
|
}
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
26
csrc/fastdeploy/vision/segmentation/segmentation_pybind.cc
Normal file
26
csrc/fastdeploy/vision/segmentation/segmentation_pybind.cc
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
void BindPPSeg(pybind11::module& m);
|
||||||
|
|
||||||
|
void BindSegmentation(pybind11::module& m) {
|
||||||
|
auto segmentation_module =
|
||||||
|
m.def_submodule("segmentation", "Image semantic segmentation models.");
|
||||||
|
BindPPSeg(segmentation_module);
|
||||||
|
}
|
||||||
|
} // namespace fastdeploy
|
@@ -16,10 +16,9 @@
|
|||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
|
|
||||||
void BindPPCls(pybind11::module& m);
|
|
||||||
void BindPPSeg(pybind11::module& m);
|
|
||||||
|
|
||||||
void BindDetection(pybind11::module& m);
|
void BindDetection(pybind11::module& m);
|
||||||
|
void BindClassification(pybind11::module& m);
|
||||||
|
void BindSegmentation(pybind11::module& m);
|
||||||
void BindMatting(pybind11::module& m);
|
void BindMatting(pybind11::module& m);
|
||||||
void BindFaceDet(pybind11::module& m);
|
void BindFaceDet(pybind11::module& m);
|
||||||
void BindFaceId(pybind11::module& m);
|
void BindFaceId(pybind11::module& m);
|
||||||
@@ -77,10 +76,9 @@ void BindVision(pybind11::module& m) {
|
|||||||
.def("__repr__", &vision::MattingResult::Str)
|
.def("__repr__", &vision::MattingResult::Str)
|
||||||
.def("__str__", &vision::MattingResult::Str);
|
.def("__str__", &vision::MattingResult::Str);
|
||||||
|
|
||||||
BindPPCls(m);
|
|
||||||
BindPPSeg(m);
|
|
||||||
|
|
||||||
BindDetection(m);
|
BindDetection(m);
|
||||||
|
BindClassification(m);
|
||||||
|
BindSegmentation(m);
|
||||||
BindFaceDet(m);
|
BindFaceDet(m);
|
||||||
BindFaceId(m);
|
BindFaceId(m);
|
||||||
BindMatting(m);
|
BindMatting(m);
|
||||||
|
@@ -19,7 +19,8 @@ from . import c_lib_wrap as C
|
|||||||
class Runtime:
|
class Runtime:
|
||||||
def __init__(self, runtime_option):
|
def __init__(self, runtime_option):
|
||||||
self._runtime = C.Runtime()
|
self._runtime = C.Runtime()
|
||||||
assert self._runtime.init(runtime_option), "Initialize Runtime Failed!"
|
assert self._runtime.init(
|
||||||
|
runtime_option._option), "Initialize Runtime Failed!"
|
||||||
|
|
||||||
def infer(self, data):
|
def infer(self, data):
|
||||||
assert isinstance(data, dict), "The input data should be type of dict."
|
assert isinstance(data, dict), "The input data should be type of dict."
|
||||||
|
@@ -15,11 +15,11 @@ from __future__ import absolute_import
|
|||||||
|
|
||||||
from . import detection
|
from . import detection
|
||||||
from . import classification
|
from . import classification
|
||||||
|
from . import segmentation
|
||||||
|
|
||||||
from . import matting
|
from . import matting
|
||||||
from . import facedet
|
from . import facedet
|
||||||
from . import faceid
|
from . import faceid
|
||||||
|
|
||||||
from . import ppseg
|
|
||||||
from . import evaluation
|
from . import evaluation
|
||||||
from .visualize import *
|
from .visualize import *
|
||||||
|
16
fastdeploy/vision/segmentation/__init__.py
Normal file
16
fastdeploy/vision/segmentation/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
from .ppseg import PaddleSegModel
|
@@ -14,11 +14,11 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import logging
|
import logging
|
||||||
from ... import FastDeployModel, Frontend
|
from .... import FastDeployModel, Frontend
|
||||||
from ... import c_lib_wrap as C
|
from .... import c_lib_wrap as C
|
||||||
|
|
||||||
|
|
||||||
class Model(FastDeployModel):
|
class PaddleSegModel(FastDeployModel):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_file,
|
model_file,
|
||||||
params_file,
|
params_file,
|
||||||
@@ -28,9 +28,9 @@ class Model(FastDeployModel):
|
|||||||
super(Model, self).__init__(backend_option)
|
super(Model, self).__init__(backend_option)
|
||||||
|
|
||||||
assert model_format == Frontend.PADDLE, "PaddleSeg only support model format of Frontend.Paddle now."
|
assert model_format == Frontend.PADDLE, "PaddleSeg only support model format of Frontend.Paddle now."
|
||||||
self._model = C.vision.ppseg.Model(model_file, params_file,
|
self._model = C.vision.segmentation.PaddleSegModel(
|
||||||
config_file, self._runtime_option,
|
model_file, params_file, config_file, self._runtime_option,
|
||||||
model_format)
|
model_format)
|
||||||
assert self.initialized, "PaddleSeg model initialize failed."
|
assert self.initialized, "PaddleSeg model initialize failed."
|
||||||
|
|
||||||
def predict(self, input_image):
|
def predict(self, input_image):
|
Reference in New Issue
Block a user