diff --git a/docs/api_docs/python/index.rst b/docs/api_docs/python/index.rst index 06d4a95cb..69b65b3b1 100644 --- a/docs/api_docs/python/index.rst +++ b/docs/api_docs/python/index.rst @@ -20,4 +20,8 @@ FastDeploy matting.md face_recognition.md face_detection.md + face_alignment.md + headpose.md vision_results_en.md + runtime.md + runtime_option.md diff --git a/docs/api_docs/python/runtime.md b/docs/api_docs/python/runtime.md new file mode 100644 index 000000000..4a519ee7e --- /dev/null +++ b/docs/api_docs/python/runtime.md @@ -0,0 +1,9 @@ +# Runtime API + +## fastdeploy.Runtime + +```{eval-rst} +.. autoclass:: fastdeploy.Runtime + :members: + :inherited-members: +``` diff --git a/docs/api_docs/python/runtime_option.md b/docs/api_docs/python/runtime_option.md new file mode 100644 index 000000000..96eff8672 --- /dev/null +++ b/docs/api_docs/python/runtime_option.md @@ -0,0 +1,9 @@ +# Runtime Option API + +## fastdeploy.RuntimeOption + +```{eval-rst} +.. autoclass:: fastdeploy.RuntimeOption + :members: + :inherited-members: +``` diff --git a/examples/runtime/cpp/CMakeLists.txt b/examples/runtime/cpp/CMakeLists.txt new file mode 100644 index 000000000..09ea45c3b --- /dev/null +++ b/examples/runtime/cpp/CMakeLists.txt @@ -0,0 +1,14 @@ +PROJECT(runtime_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.12) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(runtime_demo ${PROJECT_SOURCE_DIR}/infer_onnx_openvino.cc) +# 添加FastDeploy库依赖 +target_link_libraries(runtime_demo ${FASTDEPLOY_LIBS}) diff --git a/examples/runtime/cpp/infer_onnx_openvino.cc b/examples/runtime/cpp/infer_onnx_openvino.cc new file mode 100644 index 000000000..c2f270be9 --- /dev/null +++ b/examples/runtime/cpp/infer_onnx_openvino.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/runtime.h" + +namespace fd = fastdeploy; + +int main(int argc, char* argv[]) { + std::string model_file = "mobilenetv2.onnx"; + + // setup option + fd::RuntimeOption runtime_option; + runtime_option.SetModelPath(model_file, "", fd::ModelFormat::ONNX); + runtime_option.UseOpenVINOBackend(); + runtime_option.SetCpuThreadNum(12); + // init runtime + std::unique_ptr runtime = + std::unique_ptr(new fd::Runtime()); + if (!runtime->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << model_file << std::endl; + } + // init input tensor shape + fd::TensorInfo info = runtime->GetInputInfo(0); + info.shape = {1, 3, 224, 224}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + std::vector inputs_data; + inputs_data.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data.size(); ++i) { + inputs_data[i] = std::rand() % 1000 / 1000.0f; + } + input_tensors[0].SetExternalData({1, 3, 224, 224}, fd::FDDataType::FP32, inputs_data.data()); + + //get input name + input_tensors[0].name = info.name; + + runtime->Infer(input_tensors, &output_tensors); + + output_tensors[0].PrintInfo(); + return 0; +} \ No newline at end of file diff --git a/examples/runtime/cpp/infer_onnx_tensorrt.cc b/examples/runtime/cpp/infer_onnx_tensorrt.cc new file mode 100644 index 000000000..084c1dfae --- /dev/null +++ b/examples/runtime/cpp/infer_onnx_tensorrt.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/runtime.h" + +namespace fd = fastdeploy; + +int main(int argc, char* argv[]) { + std::string model_file = "mobilenetv2.onnx"; + + // setup option + fd::RuntimeOption runtime_option; + runtime_option.SetModelPath(model_file, "", fd::ModelFormat::ONNX); + runtime_option.UseGpu(0); + runtime_option.UseTrtBackend(); + runtime_option.SetTrtInputShape("inputs", {1, 3, 224, 224}); + // init runtime + std::unique_ptr runtime = + std::unique_ptr(new fd::Runtime()); + if (!runtime->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << model_file << std::endl; + } + // init input tensor shape + fd::TensorInfo info = runtime->GetInputInfo(0); + info.shape = {1, 3, 224, 224}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + std::vector inputs_data; + inputs_data.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data.size(); ++i) { + inputs_data[i] = std::rand() % 1000 / 1000.0f; + } + input_tensors[0].SetExternalData({1, 3, 224, 224}, fd::FDDataType::FP32, inputs_data.data()); + + //get input name + input_tensors[0].name = info.name; + + runtime->Infer(input_tensors, &output_tensors); + + output_tensors[0].PrintInfo(); + return 0; +} \ No newline at end of file diff --git a/examples/runtime/cpp/infer_paddle_onnxruntime.cc b/examples/runtime/cpp/infer_paddle_onnxruntime.cc new file mode 100644 index 000000000..d8d036a03 --- /dev/null +++ b/examples/runtime/cpp/infer_paddle_onnxruntime.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/runtime.h" + +namespace fd = fastdeploy; + +int main(int argc, char* argv[]) { + std::string model_file = "mobilenetv2/inference.pdmodel"; + std::string params_file = "mobilenetv2/inference.pdiparams"; + + // setup option + fd::RuntimeOption runtime_option; + runtime_option.SetModelPath(model_file, params_file, fd::ModelFormat::PADDLE); + runtime_option.UseOrtBackend(); + runtime_option.SetCpuThreadNum(12); + // init runtime + std::unique_ptr runtime = + std::unique_ptr(new fd::Runtime()); + if (!runtime->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << model_file << std::endl; + } + // init input tensor shape + fd::TensorInfo info = runtime->GetInputInfo(0); + info.shape = {1, 3, 224, 224}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + std::vector inputs_data; + inputs_data.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data.size(); ++i) { + inputs_data[i] = std::rand() % 1000 / 1000.0f; + } + input_tensors[0].SetExternalData({1, 3, 224, 224}, fd::FDDataType::FP32, inputs_data.data()); + + //get input name + input_tensors[0].name = info.name; + + runtime->Infer(input_tensors, &output_tensors); + + output_tensors[0].PrintInfo(); + return 0; +} \ No newline at end of file diff --git a/examples/runtime/cpp/infer_paddle_openvino.cc b/examples/runtime/cpp/infer_paddle_openvino.cc new file mode 100644 index 000000000..3958cdcf0 --- /dev/null +++ b/examples/runtime/cpp/infer_paddle_openvino.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/runtime.h" + +namespace fd = fastdeploy; + +int main(int argc, char* argv[]) { + std::string model_file = "mobilenetv2/inference.pdmodel"; + std::string params_file = "mobilenetv2/inference.pdiparams"; + + // setup option + fd::RuntimeOption runtime_option; + runtime_option.SetModelPath(model_file, params_file, fd::ModelFormat::PADDLE); + runtime_option.UseOpenVINOBackend(); + runtime_option.SetCpuThreadNum(12); + // init runtime + std::unique_ptr runtime = + std::unique_ptr(new fd::Runtime()); + if (!runtime->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << model_file << std::endl; + } + // init input tensor shape + fd::TensorInfo info = runtime->GetInputInfo(0); + info.shape = {1, 3, 224, 224}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + std::vector inputs_data; + inputs_data.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data.size(); ++i) { + inputs_data[i] = std::rand() % 1000 / 1000.0f; + } + input_tensors[0].SetExternalData({1, 3, 224, 224}, fd::FDDataType::FP32, inputs_data.data()); + + //get input name + input_tensors[0].name = info.name; + + runtime->Infer(input_tensors, &output_tensors); + + output_tensors[0].PrintInfo(); + return 0; +} \ No newline at end of file diff --git a/examples/runtime/cpp/infer_paddle_paddle_inference.cc b/examples/runtime/cpp/infer_paddle_paddle_inference.cc new file mode 100644 index 000000000..1d0bd82ad --- /dev/null +++ b/examples/runtime/cpp/infer_paddle_paddle_inference.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/runtime.h" + +namespace fd = fastdeploy; + +int main(int argc, char* argv[]) { + std::string model_file = "mobilenetv2/inference.pdmodel"; + std::string params_file = "mobilenetv2/inference.pdiparams"; + + // setup option + fd::RuntimeOption runtime_option; + runtime_option.SetModelPath(model_file, params_file, fd::ModelFormat::PADDLE); + // CPU + runtime_option.UsePaddleBackend(); + runtime_option.SetCpuThreadNum(12); + // GPU + // runtime_option.UseGpu(0); + // IPU + // runtime_option.UseIpu(); + // init runtime + std::unique_ptr runtime = + std::unique_ptr(new fd::Runtime()); + if (!runtime->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << model_file << std::endl; + } + // init input tensor shape + fd::TensorInfo info = runtime->GetInputInfo(0); + info.shape = {1, 3, 224, 224}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + std::vector inputs_data; + inputs_data.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data.size(); ++i) { + inputs_data[i] = std::rand() % 1000 / 1000.0f; + } + input_tensors[0].SetExternalData({1, 3, 224, 224}, fd::FDDataType::FP32, inputs_data.data()); + + //get input name + input_tensors[0].name = info.name; + + runtime->Infer(input_tensors, &output_tensors); + + output_tensors[0].PrintInfo(); + return 0; +} \ No newline at end of file diff --git a/examples/runtime/cpp/infer_paddle_tensorrt.cc b/examples/runtime/cpp/infer_paddle_tensorrt.cc new file mode 100644 index 000000000..04fe311b2 --- /dev/null +++ b/examples/runtime/cpp/infer_paddle_tensorrt.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/runtime.h" + +namespace fd = fastdeploy; + +int main(int argc, char* argv[]) { + std::string model_file = "mobilenetv2/inference.pdmodel"; + std::string params_file = "mobilenetv2/inference.pdiparams"; + + // setup option + fd::RuntimeOption runtime_option; + runtime_option.SetModelPath(model_file, params_file, fd::ModelFormat::PADDLE); + runtime_option.UseGpu(0); + runtime_option.UseTrtBackend(); + runtime_option.EnablePaddleToTrt(); + // init runtime + std::unique_ptr runtime = + std::unique_ptr(new fd::Runtime()); + if (!runtime->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << model_file << std::endl; + } + // init input tensor shape + fd::TensorInfo info = runtime->GetInputInfo(0); + info.shape = {1, 3, 224, 224}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + std::vector inputs_data; + inputs_data.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data.size(); ++i) { + inputs_data[i] = std::rand() % 1000 / 1000.0f; + } + input_tensors[0].SetExternalData({1, 3, 224, 224}, fd::FDDataType::FP32, inputs_data.data()); + + //get input name + input_tensors[0].name = info.name; + + runtime->Infer(input_tensors, &output_tensors); + + output_tensors[0].PrintInfo(); + return 0; +} \ No newline at end of file diff --git a/examples/runtime/python/infer_paddle_tensorrt.py b/examples/runtime/python/infer_paddle_tensorrt.py index ad2b8e197..94c95cb87 100644 --- a/examples/runtime/python/infer_paddle_tensorrt.py +++ b/examples/runtime/python/infer_paddle_tensorrt.py @@ -27,6 +27,8 @@ option.set_model_path("mobilenetv2/inference.pdmodel", # **** GPU 配置 *** option.use_gpu(0) option.use_trt_backend() +# using TensorRT integrated in Paddle Inference +# option.enable_paddle_to_trt() # 初始化构造runtime runtime = fd.Runtime(option) diff --git a/fastdeploy/function/concat.cc b/fastdeploy/function/concat.cc index c2b1f2744..32ca407c0 100644 --- a/fastdeploy/function/concat.cc +++ b/fastdeploy/function/concat.cc @@ -121,4 +121,4 @@ void Concat(const std::vector& x, FDTensor* out, int axis) { *out = std::move(out_temp); } -} // namespace fastdeploy +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc old mode 100755 new mode 100644 diff --git a/fastdeploy/vision/classification/ppcls/model.cc b/fastdeploy/vision/classification/ppcls/model.cc index 3eed25c6c..5f88e0a72 100644 --- a/fastdeploy/vision/classification/ppcls/model.cc +++ b/fastdeploy/vision/classification/ppcls/model.cc @@ -14,9 +14,6 @@ #include "fastdeploy/vision/classification/ppcls/model.h" -#include "fastdeploy/vision/utils/utils.h" -#include "yaml-cpp/yaml.h" - namespace fastdeploy { namespace vision { namespace classification { @@ -25,8 +22,7 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file, const std::string& params_file, const std::string& config_file, const RuntimeOption& custom_option, - const ModelFormat& model_format) { - config_file_ = config_file; + const ModelFormat& model_format) : preprocessor_(config_file) { valid_cpu_backends = {Backend::ORT, Backend::OPENVINO, Backend::PDINFER, Backend::LITE}; valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; @@ -38,11 +34,6 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file, } bool PaddleClasModel::Initialize() { - if (!BuildPreprocessPipelineFromConfig()) { - FDERROR << "Failed to build preprocess pipeline from configuration file." - << std::endl; - return false; - } if (!InitRuntime()) { FDERROR << "Failed to initialize fastdeploy backend." << std::endl; return false; @@ -50,105 +41,41 @@ bool PaddleClasModel::Initialize() { return true; } -bool PaddleClasModel::BuildPreprocessPipelineFromConfig() { - processors_.clear(); - YAML::Node cfg; - try { - cfg = YAML::LoadFile(config_file_); - } catch (YAML::BadFile& e) { - FDERROR << "Failed to load yaml file " << config_file_ - << ", maybe you should check this file." << std::endl; - return false; - } - auto preprocess_cfg = cfg["PreProcess"]["transform_ops"]; - processors_.push_back(std::make_shared()); - for (const auto& op : preprocess_cfg) { - FDASSERT(op.IsMap(), - "Require the transform information in yaml be Map type."); - auto op_name = op.begin()->first.as(); - if (op_name == "ResizeImage") { - int target_size = op.begin()->second["resize_short"].as(); - bool use_scale = false; - int interp = 1; - processors_.push_back( - std::make_shared(target_size, 1, use_scale)); - } else if (op_name == "CropImage") { - int width = op.begin()->second["size"].as(); - int height = op.begin()->second["size"].as(); - processors_.push_back(std::make_shared(width, height)); - } else if (op_name == "NormalizeImage") { - auto mean = op.begin()->second["mean"].as>(); - auto std = op.begin()->second["std"].as>(); - auto scale = op.begin()->second["scale"].as(); - FDASSERT((scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06, - "Only support scale in Normalize be 0.00392157, means the pixel " - "is in range of [0, 255]."); - processors_.push_back(std::make_shared(mean, std)); - } else if (op_name == "ToCHWImage") { - processors_.push_back(std::make_shared()); - } else { - FDERROR << "Unexcepted preprocess operator: " << op_name << "." - << std::endl; - return false; - } - } - return true; -} - -bool PaddleClasModel::Preprocess(Mat* mat, FDTensor* output) { - for (size_t i = 0; i < processors_.size(); ++i) { - if (!(*(processors_[i].get()))(mat)) { - FDERROR << "Failed to process image data in " << processors_[i]->Name() - << "." << std::endl; - return false; - } - } - - int channel = mat->Channels(); - int width = mat->Width(); - int height = mat->Height(); - output->name = InputInfoOfRuntime(0).name; - output->SetExternalData({1, channel, height, width}, FDDataType::FP32, - mat->Data()); - return true; -} - -bool PaddleClasModel::Postprocess(const FDTensor& infer_result, - ClassifyResult* result, int topk) { - int num_classes = infer_result.shape[1]; - const float* infer_result_buffer = - reinterpret_cast(infer_result.Data()); - topk = std::min(num_classes, topk); - result->label_ids = - utils::TopKIndices(infer_result_buffer, num_classes, topk); - result->scores.resize(topk); - for (int i = 0; i < topk; ++i) { - result->scores[i] = *(infer_result_buffer + result->label_ids[i]); - } - return true; -} - bool PaddleClasModel::Predict(cv::Mat* im, ClassifyResult* result, int topk) { - Mat mat(*im); - std::vector processed_data(1); - if (!Preprocess(&mat, &(processed_data[0]))) { - FDERROR << "Failed to preprocess input data while using model:" - << ModelName() << "." << std::endl; + postprocessor_.SetTopk(topk); + if (!Predict(*im, result)) { + return false; + } + return true; +} + +bool PaddleClasModel::Predict(const cv::Mat& im, ClassifyResult* result) { + std::vector results; + if (!BatchPredict({im}, &results)) { + return false; + } + *result = std::move(results[0]); + return true; +} + +bool PaddleClasModel::BatchPredict(const std::vector& images, std::vector* results) { + std::vector fd_images = WrapMat(images); + if (!preprocessor_.Run(&fd_images, &reused_input_tensors)) { + FDERROR << "Failed to preprocess the input image." << std::endl; return false; } - std::vector infer_result(1); - if (!Infer(processed_data, &infer_result)) { - FDERROR << "Failed to inference while using model:" << ModelName() << "." - << std::endl; + reused_input_tensors[0].name = InputInfoOfRuntime(0).name; + if (!Infer(reused_input_tensors, &reused_output_tensors)) { + FDERROR << "Failed to inference by runtime." << std::endl; return false; } - if (!Postprocess(infer_result[0], result, topk)) { - FDERROR << "Failed to postprocess while using model:" << ModelName() << "." - << std::endl; + if (!postprocessor_.Run(reused_output_tensors, results)) { + FDERROR << "Failed to postprocess the inference results by runtime." << std::endl; return false; } + return true; } diff --git a/fastdeploy/vision/classification/ppcls/model.h b/fastdeploy/vision/classification/ppcls/model.h index dc0d7452b..e24477073 100644 --- a/fastdeploy/vision/classification/ppcls/model.h +++ b/fastdeploy/vision/classification/ppcls/model.h @@ -14,8 +14,8 @@ #pragma once #include "fastdeploy/fastdeploy_model.h" -#include "fastdeploy/vision/common/processors/transform.h" -#include "fastdeploy/vision/common/result.h" +#include "fastdeploy/vision/classification/ppcls/preprocessor.h" +#include "fastdeploy/vision/classification/ppcls/postprocessor.h" namespace fastdeploy { namespace vision { @@ -43,28 +43,46 @@ class FASTDEPLOY_DECL PaddleClasModel : public FastDeployModel { /// Get model's name virtual std::string ModelName() const { return "PaddleClas/Model"; } - /** \brief Predict the classification result for an input image + /** \brief DEPRECATED Predict the classification result for an input image, remove at 1.0 version * * \param[in] im The input image data, comes from cv::imread() * \param[in] result The output classification result will be writen to this structure - * \param[in] topk (int)The topk result by the classify confidence score, default 1 * \return true if the prediction successed, otherwise false */ - // TODO(jiangjiajun) Batch is on the way virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1); + + /** \brief Predict the classification result for an input image + * + * \param[in] img The input image data, comes from cv::imread() + * \param[in] result The output classification result + * \return true if the prediction successed, otherwise false + */ + virtual bool Predict(const cv::Mat& img, ClassifyResult* result); + + /** \brief Predict the classification results for a batch of input images + * + * \param[in] imgs, The input image list, each element comes from cv::imread() + * \param[in] results The output classification result list + * \return true if the prediction successed, otherwise false + */ + virtual bool BatchPredict(const std::vector& imgs, + std::vector* results); + + /// Get preprocessor reference of PaddleClasModel + virtual PaddleClasPreprocessor& GetPreprocessor() { + return preprocessor_; + } + + /// Get postprocessor reference of PaddleClasModel + virtual PaddleClasPostprocessor& GetPostprocessor() { + return postprocessor_; + } + protected: bool Initialize(); - - bool BuildPreprocessPipelineFromConfig(); - - bool Preprocess(Mat* mat, FDTensor* outputs); - - bool Postprocess(const FDTensor& infer_result, ClassifyResult* result, - int topk = 1); - - std::vector> processors_; - std::string config_file_; + PaddleClasPreprocessor preprocessor_; + PaddleClasPostprocessor postprocessor_; }; typedef PaddleClasModel PPLCNet; diff --git a/fastdeploy/vision/classification/ppcls/postprocessor.cc b/fastdeploy/vision/classification/ppcls/postprocessor.cc new file mode 100644 index 000000000..34618035f --- /dev/null +++ b/fastdeploy/vision/classification/ppcls/postprocessor.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/classification/ppcls/postprocessor.h" +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { +namespace vision { +namespace classification { + +PaddleClasPostprocessor::PaddleClasPostprocessor(int topk) { + topk_ = topk; + initialized_ = true; +} + +bool PaddleClasPostprocessor::Run(const std::vector& infer_result, std::vector* results) { + if (!initialized_) { + FDERROR << "Postprocessor is not initialized." << std::endl; + return false; + } + + int batch = infer_result[0].shape[0]; + int num_classes = infer_result[0].shape[1]; + const float* infer_result_data = reinterpret_cast(infer_result[0].Data()); + + results->resize(batch); + + int topk = std::min(num_classes, topk_); + for (int i = 0; i < batch; ++i) { + (*results)[i].label_ids = utils::TopKIndices(infer_result_data + i * num_classes, num_classes, topk); + (*results)[i].scores.resize(topk); + for (int j = 0; j < topk; ++j) { + (*results)[i].scores[j] = infer_result_data[i * num_classes + (*results)[i].label_ids[j]]; + } + } + + return true; +} + +} // namespace classification +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/classification/ppcls/postprocessor.h b/fastdeploy/vision/classification/ppcls/postprocessor.h new file mode 100644 index 000000000..5623db36b --- /dev/null +++ b/fastdeploy/vision/classification/ppcls/postprocessor.h @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace classification { +/*! @brief Postprocessor object for PaddleClas serials model. + */ +class FASTDEPLOY_DECL PaddleClasPostprocessor { + public: + /** \brief Create a postprocessor instance for PaddleClas serials model + * + * \param[in] topk The topk result filtered by the classify confidence score, default 1 + */ + explicit PaddleClasPostprocessor(int topk = 1); + + /** \brief Process the result of runtime and fill to ClassifyResult structure + * + * \param[in] tensors The inference result from runtime + * \param[in] result The output result of classification + * \return true if the postprocess successed, otherwise false + */ + bool Run(const std::vector& tensors, + std::vector* result); + + /// Set topk value + void SetTopk(int topk) { topk_ = topk; } + + /// Get topk value + int GetTopk() const { return topk_; } + + private: + int topk_ = 1; + bool initialized_ = false; +}; + +} // namespace classification +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc b/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc index ea4e386f2..dc63744cc 100644 --- a/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc +++ b/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc @@ -15,16 +15,62 @@ namespace fastdeploy { void BindPaddleClas(pybind11::module& m) { + pybind11::class_( + m, "PaddleClasPreprocessor") + .def(pybind11::init()) + .def("run", [](vision::classification::PaddleClasPreprocessor& self, std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + if (!self.Run(&images, &outputs)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleClasPreprocessor.')"); + } + return outputs; + }); + + pybind11::class_( + m, "PaddleClasPostprocessor") + .def(pybind11::init()) + .def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector& inputs) { + std::vector results; + if (!self.Run(inputs, &results)) { + pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleClasPostprocessor.')"); + } + return results; + }) + .def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector& input_array) { + std::vector results; + std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + if (!self.Run(inputs, &results)) { + pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleClasPostprocessor.')"); + } + return results; + }) + .def_property("topk", &vision::classification::PaddleClasPostprocessor::GetTopk, &vision::classification::PaddleClasPostprocessor::SetTopk); + pybind11::class_( m, "PaddleClasModel") .def(pybind11::init()) - .def("predict", [](vision::classification::PaddleClasModel& self, - pybind11::array& data, int topk = 1) { - auto mat = PyArrayToCvMat(data); - vision::ClassifyResult res; - self.Predict(&mat, &res, topk); - return res; - }); + .def("predict", [](vision::classification::PaddleClasModel& self, pybind11::array& data) { + cv::Mat im = PyArrayToCvMat(data); + vision::ClassifyResult result; + self.Predict(im, &result); + return result; + }) + .def("batch_predict", [](vision::classification::PaddleClasModel& self, std::vector& data) { + std::vector images; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + std::vector results; + self.BatchPredict(images, &results); + return results; + }) + .def_property_readonly("preprocessor", &vision::classification::PaddleClasModel::GetPreprocessor) + .def_property_readonly("postprocessor", &vision::classification::PaddleClasModel::GetPostprocessor); } } // namespace fastdeploy diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.cc b/fastdeploy/vision/classification/ppcls/preprocessor.cc new file mode 100644 index 000000000..d2aaca2c7 --- /dev/null +++ b/fastdeploy/vision/classification/ppcls/preprocessor.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/classification/ppcls/preprocessor.h" +#include "fastdeploy/function/concat.h" +#include "yaml-cpp/yaml.h" + +namespace fastdeploy { +namespace vision { +namespace classification { + +PaddleClasPreprocessor::PaddleClasPreprocessor(const std::string& config_file) { + FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleClasPreprocessor."); + initialized_ = true; +} + +bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) { + processors_.clear(); + YAML::Node cfg; + try { + cfg = YAML::LoadFile(config_file); + } catch (YAML::BadFile& e) { + FDERROR << "Failed to load yaml file " << config_file + << ", maybe you should check this file." << std::endl; + return false; + } + auto preprocess_cfg = cfg["PreProcess"]["transform_ops"]; + processors_.push_back(std::make_shared()); + for (const auto& op : preprocess_cfg) { + FDASSERT(op.IsMap(), + "Require the transform information in yaml be Map type."); + auto op_name = op.begin()->first.as(); + if (op_name == "ResizeImage") { + int target_size = op.begin()->second["resize_short"].as(); + bool use_scale = false; + int interp = 1; + processors_.push_back( + std::make_shared(target_size, 1, use_scale)); + } else if (op_name == "CropImage") { + int width = op.begin()->second["size"].as(); + int height = op.begin()->second["size"].as(); + processors_.push_back(std::make_shared(width, height)); + } else if (op_name == "NormalizeImage") { + auto mean = op.begin()->second["mean"].as>(); + auto std = op.begin()->second["std"].as>(); + auto scale = op.begin()->second["scale"].as(); + FDASSERT((scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06, + "Only support scale in Normalize be 0.00392157, means the pixel " + "is in range of [0, 255]."); + processors_.push_back(std::make_shared(mean, std)); + } else if (op_name == "ToCHWImage") { + processors_.push_back(std::make_shared()); + } else { + FDERROR << "Unexcepted preprocess operator: " << op_name << "." + << std::endl; + return false; + } + } + + // Fusion will improve performance + FuseTransforms(&processors_); + return true; +} + +bool PaddleClasPreprocessor::Run(std::vector* images, std::vector* outputs) { + if (!initialized_) { + FDERROR << "The preprocessor is not initialized." << std::endl; + return false; + } + if (images->size() == 0) { + FDERROR << "The size of input images should be greater than 0." << std::endl; + return false; + } + + for (size_t i = 0; i < images->size(); ++i) { + for (size_t j = 0; j < processors_.size(); ++j) { + if (!(*(processors_[j].get()))(&((*images)[i]))) { + FDERROR << "Failed to processs image:" << i << " in " << processors_[i]->Name() << "." << std::endl; + return false; + } + } + } + + outputs->resize(1); + // Concat all the preprocessed data to a batch tensor + std::vector tensors(images->size()); + for (size_t i = 0; i < images->size(); ++i) { + (*images)[i].ShareWithTensor(&(tensors[i])); + tensors[i].ExpandDim(0); + } + Concat(tensors, &((*outputs)[0]), 0); + return true; +} + +} // namespace classification +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.h b/fastdeploy/vision/classification/ppcls/preprocessor.h new file mode 100644 index 000000000..38588f89a --- /dev/null +++ b/fastdeploy/vision/classification/ppcls/preprocessor.h @@ -0,0 +1,50 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace classification { +/*! @brief Preprocessor object for PaddleClas serials model. + */ +class FASTDEPLOY_DECL PaddleClasPreprocessor { + public: + /** \brief Create a preprocessor instance for PaddleClas serials model + * + * \param[in] config_file Path of configuration file for deployment, e.g resnet/infer_cfg.yml + */ + explicit PaddleClasPreprocessor(const std::string& config_file); + + /** \brief Process the input image and prepare input tensors for runtime + * + * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] outputs The output tensors which will feed in runtime + * \return true if the preprocess successed, otherwise false + */ + bool Run(std::vector* images, std::vector* outputs); + + + private: + bool BuildPreprocessPipelineFromConfig(const std::string& config_file); + std::vector> processors_; + bool initialized_ = false; +}; + +} // namespace classification +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/base.cc b/fastdeploy/vision/common/processors/base.cc index 9410f6b61..20c058446 100644 --- a/fastdeploy/vision/common/processors/base.cc +++ b/fastdeploy/vision/common/processors/base.cc @@ -54,5 +54,12 @@ void DisableFlyCV() { << DefaultProcLib::default_lib << std::endl; } +void SetProcLibCpuNumThreads(int threads) { + cv::setNumThreads(threads); +#ifdef ENABLE_FLYCV + fcv::set_thread_num(threads); +#endif +} + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/base.h b/fastdeploy/vision/common/processors/base.h index 6c67d10bc..e9d7ede90 100644 --- a/fastdeploy/vision/common/processors/base.h +++ b/fastdeploy/vision/common/processors/base.h @@ -31,6 +31,11 @@ FASTDEPLOY_DECL void EnableFlyCV(); /// Disable using FlyCV to process image while deploy vision models. FASTDEPLOY_DECL void DisableFlyCV(); +/*! @brief Set the cpu num threads of ProcLib. The cpu num threads + * of FlyCV and OpenCV is 2 by default. + */ +FASTDEPLOY_DECL void SetProcLibCpuNumThreads(int threads); + class FASTDEPLOY_DECL Processor { public: // default_lib has the highest priority diff --git a/fastdeploy/vision/common/processors/resize_by_short.cc b/fastdeploy/vision/common/processors/resize_by_short.cc index 5616961f1..de480a2eb 100644 --- a/fastdeploy/vision/common/processors/resize_by_short.cc +++ b/fastdeploy/vision/common/processors/resize_by_short.cc @@ -51,7 +51,7 @@ bool ResizeByShort::ImplByFlyCV(Mat* mat) { } else if (interp_ == 2) { interp_method = fcv::InterpolationType::INTER_CUBIC; } else { - FDERROR << "LimitLong: Only support interp_ be 0/1/2 with FlyCV, but " + FDERROR << "LimitByShort: Only support interp_ be 0/1/2 with FlyCV, but " "now it's " << interp_ << "." << std::endl; return false; diff --git a/fastdeploy/vision/common/result.cc b/fastdeploy/vision/common/result.cc index 760acb51d..3585713a6 100755 --- a/fastdeploy/vision/common/result.cc +++ b/fastdeploy/vision/common/result.cc @@ -35,6 +35,14 @@ std::string ClassifyResult::Str() { return out; } +ClassifyResult& ClassifyResult::operator=(ClassifyResult&& other) { + if (&other != this) { + label_ids = std::move(other.label_ids); + scores = std::move(other.scores); + } + return *this; +} + void Mask::Reserve(int size) { data.reserve(size); } void Mask::Resize(int size) { data.resize(size); } diff --git a/fastdeploy/vision/common/result.h b/fastdeploy/vision/common/result.h index 771bd62b1..59690ab6d 100755 --- a/fastdeploy/vision/common/result.h +++ b/fastdeploy/vision/common/result.h @@ -44,6 +44,7 @@ struct FASTDEPLOY_DECL BaseResult { /*! @brief Classify result structure for all the image classify models */ struct FASTDEPLOY_DECL ClassifyResult : public BaseResult { + ClassifyResult() = default; /// Classify result for an image std::vector label_ids; /// The confidence for each classify result @@ -53,6 +54,11 @@ struct FASTDEPLOY_DECL ClassifyResult : public BaseResult { /// Clear result void Clear(); + /// Copy constructor + ClassifyResult(const ClassifyResult& other) = default; + /// Move assignment + ClassifyResult& operator=(ClassifyResult&& other); + /// Debug function, convert the result to string to print std::string Str(); }; diff --git a/java/android/app/src/main/AndroidManifest.xml b/java/android/app/src/main/AndroidManifest.xml index 83dfc7542..ab3b31b87 100644 --- a/java/android/app/src/main/AndroidManifest.xml +++ b/java/android/app/src/main/AndroidManifest.xml @@ -15,14 +15,14 @@ android:roundIcon="@mipmap/ic_launcher_round" android:supportsRtl="true" android:theme="@style/AppTheme"> - + diff --git a/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/MainActivity.java b/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/DetectionMainActivity.java similarity index 91% rename from java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/MainActivity.java rename to java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/DetectionMainActivity.java index e071ae56e..704f34058 100644 --- a/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/MainActivity.java +++ b/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/DetectionMainActivity.java @@ -44,8 +44,8 @@ import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; -public class MainActivity extends Activity implements View.OnClickListener, CameraSurfaceView.OnTextureChangedListener { - private static final String TAG = MainActivity.class.getSimpleName(); +public class DetectionMainActivity extends Activity implements View.OnClickListener, CameraSurfaceView.OnTextureChangedListener { + private static final String TAG = DetectionMainActivity.class.getSimpleName(); CameraSurfaceView svPreview; TextView tvStatus; @@ -90,7 +90,7 @@ public class MainActivity extends Activity implements View.OnClickListener, Came requestWindowFeature(Window.FEATURE_NO_TITLE); getWindow().setFlags(WindowManager.LayoutParams.FLAG_FULLSCREEN, WindowManager.LayoutParams.FLAG_FULLSCREEN); - setContentView(R.layout.default_activity_main); + setContentView(R.layout.detection_activity_main); // Clear all setting items to avoid app crashing due to the incorrect settings initSettings(); @@ -121,7 +121,7 @@ public class MainActivity extends Activity implements View.OnClickListener, Came resultImage.setImageBitmap(shutterBitmap); break; case R.id.btn_settings: - startActivity(new Intent(MainActivity.this, SettingsActivity.class)); + startActivity(new Intent(DetectionMainActivity.this, DetectionSettingsActivity.class)); break; case R.id.realtime_toggle_btn: toggleRealtimeStyle(); @@ -216,11 +216,11 @@ public class MainActivity extends Activity implements View.OnClickListener, Came originShutterBitmap = ARGB8888ImageBitmap.copy(Bitmap.Config.ARGB_8888,true); boolean modified = false; DetectionResult result = predictor.predict( - ARGB8888ImageBitmap, savedImagePath, SettingsActivity.scoreThreshold); + ARGB8888ImageBitmap, savedImagePath, DetectionSettingsActivity.scoreThreshold); modified = result.initialized(); if (!savedImagePath.isEmpty()) { synchronized (this) { - MainActivity.this.savedImagePath = "result.jpg"; + DetectionMainActivity.this.savedImagePath = "result.jpg"; } } lastFrameIndex++; @@ -325,7 +325,7 @@ public class MainActivity extends Activity implements View.OnClickListener, Came results.add(new BaseResultModel(1, "cup", 0.4f)); results.add(new BaseResultModel(2, "pen", 0.6f)); results.add(new BaseResultModel(3, "tang", 1.0f)); - final DetectResultAdapter adapter = new DetectResultAdapter(this, R.layout.default_result_page_item, results); + final DetectResultAdapter adapter = new DetectResultAdapter(this, R.layout.detection_result_page_item, results); detectResultView.setAdapter(adapter); detectResultView.invalidate(); @@ -375,25 +375,25 @@ public class MainActivity extends Activity implements View.OnClickListener, Came SharedPreferences.Editor editor = sharedPreferences.edit(); editor.clear(); editor.commit(); - SettingsActivity.resetSettings(); + DetectionSettingsActivity.resetSettings(); } public void checkAndUpdateSettings() { - if (SettingsActivity.checkAndUpdateSettings(this)) { - String realModelDir = getCacheDir() + "/" + SettingsActivity.modelDir; - Utils.copyDirectoryFromAssets(this, SettingsActivity.modelDir, realModelDir); - String realLabelPath = getCacheDir() + "/" + SettingsActivity.labelPath; - Utils.copyFileFromAssets(this, SettingsActivity.labelPath, realLabelPath); + if (DetectionSettingsActivity.checkAndUpdateSettings(this)) { + String realModelDir = getCacheDir() + "/" + DetectionSettingsActivity.modelDir; + Utils.copyDirectoryFromAssets(this, DetectionSettingsActivity.modelDir, realModelDir); + String realLabelPath = getCacheDir() + "/" + DetectionSettingsActivity.labelPath; + Utils.copyFileFromAssets(this, DetectionSettingsActivity.labelPath, realLabelPath); String modelFile = realModelDir + "/" + "model.pdmodel"; String paramsFile = realModelDir + "/" + "model.pdiparams"; String configFile = realModelDir + "/" + "infer_cfg.yml"; String labelFile = realLabelPath; RuntimeOption option = new RuntimeOption(); - option.setCpuThreadNum(SettingsActivity.cpuThreadNum); - option.setLitePowerMode(SettingsActivity.cpuPowerMode); + option.setCpuThreadNum(DetectionSettingsActivity.cpuThreadNum); + option.setLitePowerMode(DetectionSettingsActivity.cpuPowerMode); option.enableRecordTimeOfRuntime(); - if (Boolean.parseBoolean(SettingsActivity.enableLiteFp16)) { + if (Boolean.parseBoolean(DetectionSettingsActivity.enableLiteFp16)) { option.enableLiteFp16(); } predictor.init(modelFile, paramsFile, configFile, labelFile, option); @@ -405,7 +405,7 @@ public class MainActivity extends Activity implements View.OnClickListener, Came @NonNull int[] grantResults) { super.onRequestPermissionsResult(requestCode, permissions, grantResults); if (grantResults[0] != PackageManager.PERMISSION_GRANTED || grantResults[1] != PackageManager.PERMISSION_GRANTED) { - new AlertDialog.Builder(MainActivity.this) + new AlertDialog.Builder(DetectionMainActivity.this) .setTitle("Permission denied") .setMessage("Click to force quit the app, then open Settings->Apps & notifications->Target " + "App->Permissions to grant all of the permissions.") @@ -413,7 +413,7 @@ public class MainActivity extends Activity implements View.OnClickListener, Came .setPositiveButton("Exit", new DialogInterface.OnClickListener() { @Override public void onClick(DialogInterface dialog, int which) { - MainActivity.this.finish(); + DetectionMainActivity.this.finish(); } }).show(); } diff --git a/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/SettingsActivity.java b/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/DetectionSettingsActivity.java similarity index 98% rename from java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/SettingsActivity.java rename to java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/DetectionSettingsActivity.java index bcb56d703..e31c228cf 100644 --- a/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/SettingsActivity.java +++ b/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/detection/DetectionSettingsActivity.java @@ -16,9 +16,9 @@ import com.baidu.paddle.fastdeploy.app.ui.Utils; import java.util.ArrayList; import java.util.List; -public class SettingsActivity extends AppCompatPreferenceActivity implements +public class DetectionSettingsActivity extends AppCompatPreferenceActivity implements SharedPreferences.OnSharedPreferenceChangeListener { - private static final String TAG = SettingsActivity.class.getSimpleName(); + private static final String TAG = DetectionSettingsActivity.class.getSimpleName(); static public int selectedModelIdx = -1; static public String modelDir = ""; diff --git a/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/MainActivity.java b/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/OcrMainActivity.java similarity index 84% rename from java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/MainActivity.java rename to java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/OcrMainActivity.java index df53d42eb..ab61f366c 100644 --- a/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/MainActivity.java +++ b/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/OcrMainActivity.java @@ -25,8 +25,8 @@ import android.widget.Toast; import com.baidu.paddle.fastdeploy.RuntimeOption; import com.baidu.paddle.fastdeploy.app.examples.R; -import com.baidu.paddle.fastdeploy.app.ui.CameraSurfaceView; -import com.baidu.paddle.fastdeploy.app.ui.view.Utils; +import com.baidu.paddle.fastdeploy.app.ui.Utils; +import com.baidu.paddle.fastdeploy.app.ui.view.CameraSurfaceView; import com.baidu.paddle.fastdeploy.vision.OCRResult; import com.baidu.paddle.fastdeploy.pipeline.PPOCRv2; import com.baidu.paddle.fastdeploy.vision.ocr.Classifier; @@ -37,8 +37,8 @@ import java.io.File; import java.text.SimpleDateFormat; import java.util.Date; -public class MainActivity extends Activity implements View.OnClickListener, CameraSurfaceView.OnTextureChangedListener { - private static final String TAG = MainActivity.class.getSimpleName(); +public class OcrMainActivity extends Activity implements View.OnClickListener, CameraSurfaceView.OnTextureChangedListener { + private static final String TAG = OcrMainActivity.class.getSimpleName(); CameraSurfaceView svPreview; TextView tvStatus; @@ -64,7 +64,7 @@ public class MainActivity extends Activity implements View.OnClickListener, Came requestWindowFeature(Window.FEATURE_NO_TITLE); getWindow().setFlags(WindowManager.LayoutParams.FLAG_FULLSCREEN, WindowManager.LayoutParams.FLAG_FULLSCREEN); - setContentView(R.layout.activity_main); + setContentView(R.layout.ocr_activity_main); // Clear all setting items to avoid app crashing due to the incorrect settings initSettings(); @@ -91,10 +91,10 @@ public class MainActivity extends Activity implements View.OnClickListener, Came synchronized (this) { savedImagePath = Utils.getDCIMDirectory() + File.separator + date.format(new Date()).toString() + ".png"; } - Toast.makeText(MainActivity.this, "Save snapshot to " + savedImagePath, Toast.LENGTH_SHORT).show(); + Toast.makeText(OcrMainActivity.this, "Save snapshot to " + savedImagePath, Toast.LENGTH_SHORT).show(); break; case R.id.btn_settings: - startActivity(new Intent(MainActivity.this, SettingsActivity.class)); + startActivity(new Intent(OcrMainActivity.this, OcrSettingsActivity.class)); break; case R.id.realtime_toggle_btn: toggleRealtimeStyle(); @@ -128,14 +128,14 @@ public class MainActivity extends Activity implements View.OnClickListener, Came public boolean onTextureChanged(Bitmap ARGB8888ImageBitmap) { String savedImagePath = ""; synchronized (this) { - savedImagePath = MainActivity.this.savedImagePath; + savedImagePath = OcrMainActivity.this.savedImagePath; } boolean modified = false; OCRResult result = predictor.predict(ARGB8888ImageBitmap, savedImagePath); modified = result.initialized(); if (!savedImagePath.isEmpty()) { synchronized (this) { - MainActivity.this.savedImagePath = "result.jpg"; + OcrMainActivity.this.savedImagePath = "result.jpg"; } } lastFrameIndex++; @@ -201,12 +201,12 @@ public class MainActivity extends Activity implements View.OnClickListener, Came SharedPreferences.Editor editor = sharedPreferences.edit(); editor.clear(); editor.commit(); - SettingsActivity.resetSettings(); + OcrSettingsActivity.resetSettings(); } public void checkAndUpdateSettings() { - if (SettingsActivity.checkAndUpdateSettings(this)) { - String realModelDir = getCacheDir() + "/" + SettingsActivity.modelDir; + if (OcrSettingsActivity.checkAndUpdateSettings(this)) { + String realModelDir = getCacheDir() + "/" + OcrSettingsActivity.modelDir; // String detModelName = "ch_PP-OCRv2_det_infer"; String detModelName = "ch_PP-OCRv3_det_infer"; // String detModelName = "ch_ppocr_mobile_v2.0_det_infer"; @@ -217,14 +217,14 @@ public class MainActivity extends Activity implements View.OnClickListener, Came String realDetModelDir = realModelDir + "/" + detModelName; String realClsModelDir = realModelDir + "/" + clsModelName; String realRecModelDir = realModelDir + "/" + recModelName; - String srcDetModelDir = SettingsActivity.modelDir + "/" + detModelName; - String srcClsModelDir = SettingsActivity.modelDir + "/" + clsModelName; - String srcRecModelDir = SettingsActivity.modelDir + "/" + recModelName; + String srcDetModelDir = OcrSettingsActivity.modelDir + "/" + detModelName; + String srcClsModelDir = OcrSettingsActivity.modelDir + "/" + clsModelName; + String srcRecModelDir = OcrSettingsActivity.modelDir + "/" + recModelName; Utils.copyDirectoryFromAssets(this, srcDetModelDir, realDetModelDir); Utils.copyDirectoryFromAssets(this, srcClsModelDir, realClsModelDir); Utils.copyDirectoryFromAssets(this, srcRecModelDir, realRecModelDir); - String realLabelPath = getCacheDir() + "/" + SettingsActivity.labelPath; - Utils.copyFileFromAssets(this, SettingsActivity.labelPath, realLabelPath); + String realLabelPath = getCacheDir() + "/" + OcrSettingsActivity.labelPath; + Utils.copyFileFromAssets(this, OcrSettingsActivity.labelPath, realLabelPath); String detModelFile = realDetModelDir + "/" + "inference.pdmodel"; String detParamsFile = realDetModelDir + "/" + "inference.pdiparams"; @@ -236,16 +236,16 @@ public class MainActivity extends Activity implements View.OnClickListener, Came RuntimeOption detOption = new RuntimeOption(); RuntimeOption clsOption = new RuntimeOption(); RuntimeOption recOption = new RuntimeOption(); - detOption.setCpuThreadNum(SettingsActivity.cpuThreadNum); - clsOption.setCpuThreadNum(SettingsActivity.cpuThreadNum); - recOption.setCpuThreadNum(SettingsActivity.cpuThreadNum); - detOption.setLitePowerMode(SettingsActivity.cpuPowerMode); - clsOption.setLitePowerMode(SettingsActivity.cpuPowerMode); - recOption.setLitePowerMode(SettingsActivity.cpuPowerMode); + detOption.setCpuThreadNum(OcrSettingsActivity.cpuThreadNum); + clsOption.setCpuThreadNum(OcrSettingsActivity.cpuThreadNum); + recOption.setCpuThreadNum(OcrSettingsActivity.cpuThreadNum); + detOption.setLitePowerMode(OcrSettingsActivity.cpuPowerMode); + clsOption.setLitePowerMode(OcrSettingsActivity.cpuPowerMode); + recOption.setLitePowerMode(OcrSettingsActivity.cpuPowerMode); detOption.enableRecordTimeOfRuntime(); clsOption.enableRecordTimeOfRuntime(); recOption.enableRecordTimeOfRuntime(); - if (Boolean.parseBoolean(SettingsActivity.enableLiteFp16)) { + if (Boolean.parseBoolean(OcrSettingsActivity.enableLiteFp16)) { detOption.enableLiteFp16(); clsOption.enableLiteFp16(); recOption.enableLiteFp16(); @@ -263,7 +263,7 @@ public class MainActivity extends Activity implements View.OnClickListener, Came @NonNull int[] grantResults) { super.onRequestPermissionsResult(requestCode, permissions, grantResults); if (grantResults[0] != PackageManager.PERMISSION_GRANTED || grantResults[1] != PackageManager.PERMISSION_GRANTED) { - new AlertDialog.Builder(MainActivity.this) + new AlertDialog.Builder(OcrMainActivity.this) .setTitle("Permission denied") .setMessage("Click to force quit the app, then open Settings->Apps & notifications->Target " + "App->Permissions to grant all of the permissions.") @@ -271,7 +271,7 @@ public class MainActivity extends Activity implements View.OnClickListener, Came .setPositiveButton("Exit", new DialogInterface.OnClickListener() { @Override public void onClick(DialogInterface dialog, int which) { - MainActivity.this.finish(); + OcrMainActivity.this.finish(); } }).show(); } diff --git a/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/SettingsActivity.java b/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/OcrSettingsActivity.java similarity index 97% rename from java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/SettingsActivity.java rename to java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/OcrSettingsActivity.java index 27b0c9e43..6f8c45ff4 100644 --- a/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/SettingsActivity.java +++ b/java/android/app/src/main/java/com/baidu/paddle/fastdeploy/app/examples/ocr/OcrSettingsActivity.java @@ -10,15 +10,15 @@ import android.preference.PreferenceManager; import android.support.v7.app.ActionBar; import com.baidu.paddle.fastdeploy.app.examples.R; -import com.baidu.paddle.fastdeploy.app.ui.AppCompatPreferenceActivity; -import com.baidu.paddle.fastdeploy.app.ui.view.Utils; +import com.baidu.paddle.fastdeploy.app.ui.Utils; +import com.baidu.paddle.fastdeploy.app.ui.view.AppCompatPreferenceActivity; import java.util.ArrayList; import java.util.List; -public class SettingsActivity extends AppCompatPreferenceActivity implements +public class OcrSettingsActivity extends AppCompatPreferenceActivity implements SharedPreferences.OnSharedPreferenceChangeListener { - private static final String TAG = SettingsActivity.class.getSimpleName(); + private static final String TAG = OcrSettingsActivity.class.getSimpleName(); static public int selectedModelIdx = -1; static public String modelDir = ""; diff --git a/java/android/app/src/main/res/layout-land/default_activity_main.xml b/java/android/app/src/main/res/layout-land/default_activity_main.xml deleted file mode 100644 index 4cae72e1d..000000000 --- a/java/android/app/src/main/res/layout-land/default_activity_main.xml +++ /dev/null @@ -1,99 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/java/android/app/src/main/res/layout-land/detection_activity_main.xml b/java/android/app/src/main/res/layout-land/detection_activity_main.xml new file mode 100644 index 000000000..1a75970f4 --- /dev/null +++ b/java/android/app/src/main/res/layout-land/detection_activity_main.xml @@ -0,0 +1,14 @@ + + + + + + + diff --git a/java/android/app/src/main/res/layout/default_activity_main.xml b/java/android/app/src/main/res/layout-land/ocr_activity_main.xml similarity index 80% rename from java/android/app/src/main/res/layout/default_activity_main.xml rename to java/android/app/src/main/res/layout-land/ocr_activity_main.xml index 2ceb83eaf..b30f35edf 100644 --- a/java/android/app/src/main/res/layout/default_activity_main.xml +++ b/java/android/app/src/main/res/layout-land/ocr_activity_main.xml @@ -4,11 +4,11 @@ android:layout_height="match_parent"> diff --git a/java/android/app/src/main/res/layout/detection_activity_main.xml b/java/android/app/src/main/res/layout/detection_activity_main.xml new file mode 100644 index 000000000..1a75970f4 --- /dev/null +++ b/java/android/app/src/main/res/layout/detection_activity_main.xml @@ -0,0 +1,14 @@ + + + + + + + diff --git a/java/android/app/src/main/res/layout/detection_camera_page.xml b/java/android/app/src/main/res/layout/detection_camera_page.xml new file mode 100644 index 000000000..da262b58b --- /dev/null +++ b/java/android/app/src/main/res/layout/detection_camera_page.xml @@ -0,0 +1,161 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/java/android/app/src/main/res/layout/default_result_page.xml b/java/android/app/src/main/res/layout/detection_result_page.xml similarity index 100% rename from java/android/app/src/main/res/layout/default_result_page.xml rename to java/android/app/src/main/res/layout/detection_result_page.xml diff --git a/java/android/app/src/main/res/layout/default_result_page_item.xml b/java/android/app/src/main/res/layout/detection_result_page_item.xml similarity index 100% rename from java/android/app/src/main/res/layout/default_result_page_item.xml rename to java/android/app/src/main/res/layout/detection_result_page_item.xml diff --git a/java/android/app/src/main/res/layout/ocr_activity_main.xml b/java/android/app/src/main/res/layout/ocr_activity_main.xml new file mode 100644 index 000000000..b30f35edf --- /dev/null +++ b/java/android/app/src/main/res/layout/ocr_activity_main.xml @@ -0,0 +1,14 @@ + + + + + + + diff --git a/java/android/app/src/main/res/layout/default_camera_page.xml b/java/android/app/src/main/res/layout/ocr_camera_page.xml similarity index 99% rename from java/android/app/src/main/res/layout/default_camera_page.xml rename to java/android/app/src/main/res/layout/ocr_camera_page.xml index 098c50d9d..4fb00472a 100644 --- a/java/android/app/src/main/res/layout/default_camera_page.xml +++ b/java/android/app/src/main/res/layout/ocr_camera_page.xml @@ -5,7 +5,7 @@ android:layout_width="match_parent" android:layout_height="match_parent" android:keepScreenOn="true" - tools:context=".MainActivity"> + tools:context=".ocr.OcrMainActivity"> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/java/android/app/src/main/res/layout/ocr_result_page_item.xml b/java/android/app/src/main/res/layout/ocr_result_page_item.xml new file mode 100644 index 000000000..6a2b09ebf --- /dev/null +++ b/java/android/app/src/main/res/layout/ocr_result_page_item.xml @@ -0,0 +1,26 @@ + + + + + + + + + \ No newline at end of file diff --git a/java/android/fastdeploy/build.gradle b/java/android/fastdeploy/build.gradle index cf4815a94..b83ccee54 100644 --- a/java/android/fastdeploy/build.gradle +++ b/java/android/fastdeploy/build.gradle @@ -46,7 +46,7 @@ dependencies { def archives = [ [ - 'src' : 'https://bj.bcebos.com/fastdeploy/test/fastdeploy-android-0.5.0-shared-dev.tgz', + 'src' : 'https://bj.bcebos.com/fastdeploy/test/fastdeploy-android-latest-shared-dev.tgz', 'dest': 'libs' ] ] diff --git a/java/android/fastdeploy/src/main/cpp/CMakeLists.txt b/java/android/fastdeploy/src/main/cpp/CMakeLists.txt index 6298d03aa..97d772af2 100644 --- a/java/android/fastdeploy/src/main/cpp/CMakeLists.txt +++ b/java/android/fastdeploy/src/main/cpp/CMakeLists.txt @@ -12,7 +12,7 @@ project("fastdeploy_jni") # You can define multiple libraries, and CMake builds them for you. # Gradle automatically packages shared libraries with your APK. -set(FastDeploy_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/fastdeploy-android-0.5.0-shared-dev") +set(FastDeploy_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/fastdeploy-android-latest-shared-dev") find_package(FastDeploy REQUIRED) diff --git a/python/fastdeploy/vision/classification/__init__.py b/python/fastdeploy/vision/classification/__init__.py index 0b426fab1..4ca2097d0 100644 --- a/python/fastdeploy/vision/classification/__init__.py +++ b/python/fastdeploy/vision/classification/__init__.py @@ -14,8 +14,9 @@ from __future__ import absolute_import from .contrib.yolov5cls import YOLOv5Cls -from .ppcls import PaddleClasModel +from .ppcls import * from .contrib.resnet import ResNet + PPLCNet = PaddleClasModel PPLCNetv2 = PaddleClasModel EfficientNet = PaddleClasModel diff --git a/python/fastdeploy/vision/classification/ppcls/__init__.py b/python/fastdeploy/vision/classification/ppcls/__init__.py index 5672e8efb..879d54441 100644 --- a/python/fastdeploy/vision/classification/ppcls/__init__.py +++ b/python/fastdeploy/vision/classification/ppcls/__init__.py @@ -18,6 +18,42 @@ from .... import FastDeployModel, ModelFormat from .... import c_lib_wrap as C +class PaddleClasPreprocessor: + def __init__(self, config_file): + """Create a preprocessor for PaddleClasModel from configuration file + + :param config_file: (str)Path of configuration file, e.g resnet50/inference_cls.yaml + """ + self._preprocessor = C.vision.classification.PaddleClasPreprocessor( + config_file) + + def run(self, input_ims): + """Preprocess input images for PaddleClasModel + + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor + """ + return self._preprocessor.run(input_ims) + + +class PaddleClasPostprocessor: + def __init__(self, topk=1): + """Create a postprocessor for PaddleClasModel + + :param topk: (int)Filter the top k classify label + """ + self._postprocessor = C.vision.classification.PaddleClasPostprocessor( + topk) + + def run(self, runtime_results): + """Postprocess the runtime results for PaddleClasModel + + :param: runtime_results: (list of FDTensor)The output FDTensor results from runtime + :return: list of ClassifyResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results) + + class PaddleClasModel(FastDeployModel): def __init__(self, model_file, @@ -45,9 +81,35 @@ class PaddleClasModel(FastDeployModel): def predict(self, im, topk=1): """Classify an input image - :param im: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format - :param topk: (int)The topk result by the classify confidence score, default 1 + :param im: (numpy.ndarray) The input image data, a 3-D array with layout HWC, BGR format + :param topk: (int) Filter the topk classify result, default 1 :return: ClassifyResult """ - return self._model.predict(im, topk) + self.postprocessor.topk = topk + return self._model.predict(im) + + def batch_predict(self, images): + """Classify a batch of input image + + :param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return list of ClassifyResult + """ + + return self._model.batch_predict(images) + + @property + def preprocessor(self): + """Get PaddleClasPreprocessor object of the loaded model + + :return PaddleClasPreprocessor + """ + return self._model.preprocessor + + @property + def postprocessor(self): + """Get PaddleClasPostprocessor object of the loaded model + + :return PaddleClasPostprocessor + """ + return self._model.postprocessor diff --git a/tests/models/test_mobilenetv2.py b/tests/models/test_mobilenetv2.py index c2cec0220..3bedc82f1 100755 --- a/tests/models/test_mobilenetv2.py +++ b/tests/models/test_mobilenetv2.py @@ -22,9 +22,11 @@ import runtime_config as rc def test_classification_mobilenetv2(): model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV1_x0_25_infer.tgz" - input_url = "https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg" + input_url1 = "https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg" + input_url2 = "https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00030010.jpeg" fd.download_and_decompress(model_url, "resources") - fd.download(input_url, "resources") + fd.download(input_url1, "resources") + fd.download(input_url2, "resources") model_path = "resources/MobileNetV1_x0_25_infer" model_file = "resources/MobileNetV1_x0_25_infer/inference.pdmodel" @@ -33,18 +35,67 @@ def test_classification_mobilenetv2(): model = fd.vision.classification.PaddleClasModel( model_file, params_file, config_file, runtime_option=rc.test_option) - expected_label_ids = [153, 333, 259, 338, 265, 154] - expected_scores = [ + expected_label_ids_1 = [153, 333, 259, 338, 265, 154] + expected_scores_1 = [ 0.221088, 0.109457, 0.078668, 0.076814, 0.052401, 0.048206 ] + expected_label_ids_2 = [80, 23, 93, 99, 143, 7] + expected_scores_2 = [ + 0.975599, 0.014083, 0.003821, 0.001571, 0.001233, 0.000924 + ] + # compare diff - im = cv2.imread("./resources/ILSVRC2012_val_00000010.jpeg") - for i in range(2): - result = model.predict(im, topk=6) - diff_label = np.fabs( - np.array(result.label_ids) - np.array(expected_label_ids)) - diff_scores = np.fabs( - np.array(result.scores) - np.array(expected_scores)) - assert diff_label.max() < 1e-06, "There's difference in classify label." - assert diff_scores.max( - ) < 1e-05, "There's difference in classify score." + im1 = cv2.imread("./resources/ILSVRC2012_val_00000010.jpeg") + im2 = cv2.imread("./resources/ILSVRC2012_val_00030010.jpeg") + + # for i in range(3000000): + while True: + # test single predict + model.postprocessor.topk = 6 + result1 = model.predict(im1) + result2 = model.predict(im2) + + diff_label_1 = np.fabs( + np.array(result1.label_ids) - np.array(expected_label_ids_1)) + diff_label_2 = np.fabs( + np.array(result2.label_ids) - np.array(expected_label_ids_2)) + + diff_scores_1 = np.fabs( + np.array(result1.scores) - np.array(expected_scores_1)) + diff_scores_2 = np.fabs( + np.array(result2.scores) - np.array(expected_scores_2)) + assert diff_label_1.max( + ) < 1e-06, "There's difference in classify label 1." + assert diff_scores_1.max( + ) < 1e-05, "There's difference in classify score 1." + assert diff_label_2.max( + ) < 1e-06, "There's difference in classify label 2." + assert diff_scores_2.max( + ) < 1e-05, "There's difference in classify score 2." + + # test batch predict + results = model.batch_predict([im1, im2]) + result1 = results[0] + result2 = results[1] + + diff_label_1 = np.fabs( + np.array(result1.label_ids) - np.array(expected_label_ids_1)) + diff_label_2 = np.fabs( + np.array(result2.label_ids) - np.array(expected_label_ids_2)) + + diff_scores_1 = np.fabs( + np.array(result1.scores) - np.array(expected_scores_1)) + diff_scores_2 = np.fabs( + np.array(result2.scores) - np.array(expected_scores_2)) + assert diff_label_1.max( + ) < 1e-06, "There's difference in classify label 1." + assert diff_scores_1.max( + ) < 1e-05, "There's difference in classify score 1." + assert diff_label_2.max( + ) < 1e-06, "There's difference in classify label 2." + assert diff_scores_2.max( + ) < 1e-05, "There's difference in classify score 2." + + +if __name__ == "__main__": + test_classification_mobilenetv2()