diff --git a/CMakeLists.txt b/CMakeLists.txt index 14abd7fcb..b42f0c23c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,7 @@ option(ENABLE_POROS_BACKEND "Whether to enable poros backend." OFF) option(ENABLE_OPENVINO_BACKEND "Whether to enable openvino backend." OFF) option(ENABLE_RKNPU2_BACKEND "Whether to enable RKNPU2 backend." OFF) option(ENABLE_SOPHGO_BACKEND "Whether to enable SOPHON backend." OFF) +option(ENABLE_TVM_BACKEND "Whether to enable TVM backend." OFF) option(ENABLE_LITE_BACKEND "Whether to enable paddle lite backend." OFF) option(ENABLE_HORIZON_BACKEND "Whether to enable HORIZON backend." OFF) option(ENABLE_VISION "Whether to enable vision models usage." OFF) @@ -169,6 +170,7 @@ file(GLOB_RECURSE DEPLOY_OPENVINO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/f file(GLOB_RECURSE DEPLOY_RKNPU2_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/rknpu2/*.cc) file(GLOB_RECURSE DEPLOY_HORIZON_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/horizon/*.cc) file(GLOB_RECURSE DEPLOY_SOPHGO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/sophgo/*.cc) +file(GLOB_RECURSE DEPLOY_TVM_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/tvm/*.cc) file(GLOB_RECURSE DEPLOY_LITE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/lite/*.cc) file(GLOB_RECURSE DEPLOY_ENCRYPTION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/encryption/*.cc) file(GLOB_RECURSE DEPLOY_PIPELINE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pipeline/*.cc) @@ -188,7 +190,8 @@ list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS} ${DEPLOY_PIPELINE_SRCS} ${DEPLOY_RKNPU2_SRCS} - ${DEPLOY_SOPHGO_SRCS} ${DEPLOY_ENCRYPTION_SRCS} ${DEPLOY_HORIZON_SRCS}) + ${DEPLOY_SOPHGO_SRCS} ${DEPLOY_ENCRYPTION_SRCS} + ${DEPLOY_HORIZON_SRCS} ${DEPLOY_TVM_SRCS}) set(DEPEND_LIBS "") @@ -263,6 +266,14 @@ if(ENABLE_HORIZON_BACKEND) list(APPEND DEPEND_LIBS ${BPU_libs}) endif() +if(ENABLE_TVM_BACKEND) + set(CMAKE_CXX_STANDARD 17) + add_definitions(-DENABLE_TVM_BACKEND) + list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_TVM_SRCS}) + include(${PROJECT_SOURCE_DIR}/cmake/tvm.cmake) + list(APPEND DEPEND_LIBS ${TVM_RUNTIME_LIB}) +endif() + if(ENABLE_SOPHGO_BACKEND) add_definitions(-DENABLE_SOPHGO_BACKEND) list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_SOPHGO_SRCS}) diff --git a/FastDeploy.cmake.in b/FastDeploy.cmake.in index fea096299..9f0e9f40e 100644 --- a/FastDeploy.cmake.in +++ b/FastDeploy.cmake.in @@ -24,6 +24,7 @@ set(RKNN2_TARGET_SOC "@RKNN2_TARGET_SOC@") # Inference backend and FastDeploy Moudle set(ENABLE_ORT_BACKEND @ENABLE_ORT_BACKEND@) set(ENABLE_RKNPU2_BACKEND @ENABLE_RKNPU2_BACKEND@) +set(ENABLE_TVM_BACKEND @ENABLE_TVM_BACKEND@) set(ENABLE_HORIZON_BACKEND @ENABLE_HORIZON_BACKEND@) set(ENABLE_SOPHGO_BACKEND @ENABLE_SOPHGO_BACKEND@) set(ENABLE_LITE_BACKEND @ENABLE_LITE_BACKEND@) @@ -129,6 +130,15 @@ if(ENABLE_ORT_BACKEND) list(APPEND FASTDEPLOY_LIBS ${ORT_LIB}) endif() +if(ENABLE_TVM_BACKEND) + if(APPLE) + set(TVM_RUNTIME_LIB ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/tvm/lib/libtvm_runtime.dylib) + else() + set(TVM_RUNTIME_LIB ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/tvm/lib/libtvm_runtime.so) + endif() + list(APPEND FASTDEPLOY_LIBS ${TVM_RUNTIME_LIB}) +endif() + if(ENABLE_PADDLE_BACKEND) find_library(PADDLE_LIB paddle_inference ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/paddle/lib 
NO_DEFAULT_PATH)
  if(WIN32)
diff --git a/cmake/summary.cmake b/cmake/summary.cmake
index 076c28dd5..935e5910b 100755
--- a/cmake/summary.cmake
+++ b/cmake/summary.cmake
@@ -40,6 +40,7 @@ function(fastdeploy_summary)
   message(STATUS "  ENABLE_POROS_BACKEND : ${ENABLE_POROS_BACKEND}")
   message(STATUS "  ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
   message(STATUS "  ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}")
+  message(STATUS "  ENABLE_TVM_BACKEND : ${ENABLE_TVM_BACKEND}")
   message(STATUS "  ENABLE_BENCHMARK : ${ENABLE_BENCHMARK}")
   message(STATUS "  ENABLE_VISION : ${ENABLE_VISION}")
   message(STATUS "  ENABLE_TEXT : ${ENABLE_TEXT}")
diff --git a/cmake/tvm.cmake b/cmake/tvm.cmake
new file mode 100644
index 000000000..4c35ffcdf
--- /dev/null
+++ b/cmake/tvm.cmake
@@ -0,0 +1,55 @@
+# set path
+
+set(TVM_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
+set(TVM_VERSION "0.12.0")
+set(TVM_SYSTEM "")
+
+if (${CMAKE_SYSTEM} MATCHES "Darwin")
+  if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "arm64")
+    set(TVM_SYSTEM "macos-arm64")
+  endif ()
+elseif (${CMAKE_SYSTEM} MATCHES "Linux")
+  if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86")
+    set(TVM_SYSTEM "linux-x86")
+  endif ()
+else ()
+  message(FATAL_ERROR "The TVM backend only supports macOS on arm64 and Linux on x86.")
+endif ()
+set(TVM_FILE "tvm-${TVM_SYSTEM}-${TVM_VERSION}.tgz")
+set(TVM_URL "${TVM_URL_BASE}${TVM_FILE}")
+
+set(TVM_RUNTIME_PATH "${THIRD_PARTY_PATH}/install/tvm")
+execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${TVM_RUNTIME_PATH}")
+download_and_decompress(${TVM_URL}
+                        "${CMAKE_CURRENT_BINARY_DIR}/${TVM_FILE}"
+                        "${THIRD_PARTY_PATH}/install/")
+include_directories(${TVM_RUNTIME_PATH}/include)
+
+# copy dlpack to third_party
+set(DLPACK_PATH "${THIRD_PARTY_PATH}/install/dlpack")
+execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${DLPACK_PATH}")
+execute_process(COMMAND ${CMAKE_COMMAND} -E copy_directory
+                "${PROJECT_SOURCE_DIR}/third_party/dlpack"
+                "${DLPACK_PATH}")
+include_directories(${DLPACK_PATH}/include)
+
+set(DMLC_CORE_PATH "${THIRD_PARTY_PATH}/install/dmlc-core")
+execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${DMLC_CORE_PATH}")
+set(DMLC_CORE_URL https://bj.bcebos.com/fastdeploy/third_libs/dmlc-core.tgz)
+download_and_decompress(${DMLC_CORE_URL}
+                        "${CMAKE_CURRENT_BINARY_DIR}/dmlc-core.tgz"
+                        "${THIRD_PARTY_PATH}/install/")
+include_directories(${DMLC_CORE_PATH}/include)
+
+# include lib
+if (EXISTS ${TVM_RUNTIME_PATH})
+  if (${CMAKE_SYSTEM} MATCHES "Darwin")
+    set(TVM_RUNTIME_LIB ${TVM_RUNTIME_PATH}/lib/libtvm_runtime.dylib)
+  elseif (${CMAKE_SYSTEM} MATCHES "Linux")
+    set(TVM_RUNTIME_LIB ${TVM_RUNTIME_PATH}/lib/libtvm_runtime.so)
+  endif ()
+  include(${TVM_RUNTIME_PATH}/lib/cmake/tvm/tvmConfig.cmake)
+  add_definitions(-DDMLC_USE_LOGGING_LIBRARY=)
+else ()
+  message(FATAL_ERROR "[tvm.cmake] TVM_RUNTIME_PATH does not exist.")
+endif ()
\ No newline at end of file
diff --git a/examples/vision/detection/paddledetection/tvm/README.md b/examples/vision/detection/paddledetection/tvm/README.md
new file mode 100644
index 000000000..3a8c5c3f1
--- /dev/null
+++ b/examples/vision/detection/paddledetection/tvm/README.md
@@ -0,0 +1,35 @@
+[English](README.md) | 简体中文
+
+# PaddleDetection TVM Deployment Example
+
+The following PaddleDetection models have been tested on TVM:
+
+* picodet
+* PPYOLOE
+
+### Convert a Paddle model to a TVM model
+
+TVM does not support the NMS operator, so the PaddleDetection model has to be pruned before conversion, changing the model's output nodes to the input nodes of the NMS node.
+Run the following commands to obtain a pruned PPYOLOE model.
+
+```bash
+git clone https://github.com/PaddlePaddle/Paddle2ONNX.git
+cd Paddle2ONNX/tools/paddle
+wget https://bj.bcebos.com/fastdeploy/models/ppyoloe_plus_crn_m_80e_coco.tgz
+tar xvf ppyoloe_plus_crn_m_80e_coco.tgz
+python prune_paddle_model.py --model_dir ppyoloe_plus_crn_m_80e_coco \
+                             --model_filename model.pdmodel \
+                             --params_filename model.pdiparams \
+                             --output_names tmp_17 concat_14.tmp_0 \
+                             --save_dir ppyoloe_plus_crn_m_80e_coco
+```
+
+After the model has been pruned, it can be compiled with the TVM Python library; a conversion script is provided for convenience.
+Run the following commands to obtain the converted TVM model.
+Note that when running PPYOLOE, FastDeploy depends not only on the model but also on the yml config file, so you also need to copy the corresponding yml file into the model directory.
+
+```bash
+python path/to/FastDeploy/tools/tvm/paddle2tvm.py --model_path=./ppyoloe_plus_crn_m_80e_coco/model \
+                                                  --shape_dict="{'image': [1, 3, 640, 640], 'scale_factor': [1, 2]}"
+cp ppyoloe_plus_crn_m_80e_coco/infer_cfg.yml tvm_save
+```
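For reference, the `tvm_save` directory produced above can also be driven without the PPYOLOE wrapper, through the generic `fastdeploy::Runtime` API. The sketch below is not part of this patch; `SetModelPath`, `Init`, `NumInputs`, `GetInputInfo` and `Infer` are taken from the existing FastDeploy runtime headers and should be checked against your SDK version.

```c++
#include <vector>

#include "fastdeploy/runtime/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  // The model path is a prefix: the TVM backend appends ".so" and ".params".
  option.SetModelPath("./tvm_save/tvm_model", "",
                      fastdeploy::ModelFormat::TVMFormat);
  option.UseCpu();
  option.UseTVMBackend();

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) return -1;

  // Allocate inputs according to what the compiled model reports
  // ("image" [1, 3, 640, 640] and "scale_factor" [1, 2] for the PPYOLOE above).
  std::vector<fastdeploy::FDTensor> inputs(runtime.NumInputs()), outputs;
  for (int i = 0; i < runtime.NumInputs(); ++i) {
    auto info = runtime.GetInputInfo(i);
    inputs[i].Resize(std::vector<int64_t>(info.shape.begin(), info.shape.end()),
                     info.dtype, info.name);
    // ... fill inputs[i] with preprocessed data here ...
  }
  return runtime.Infer(inputs, &outputs) ? 0 : -1;
}
```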
diff --git a/examples/vision/detection/paddledetection/tvm/cpp/CMakeLists.txt b/examples/vision/detection/paddledetection/tvm/cpp/CMakeLists.txt
new file mode 100644
index 000000000..3dafcf718
--- /dev/null
+++ b/examples/vision/detection/paddledetection/tvm/cpp/CMakeLists.txt
@@ -0,0 +1,13 @@
+PROJECT(infer_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
+
+# Path of the downloaded and extracted FastDeploy SDK
+option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# Add the FastDeploy header search paths
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(infer_ppyoloe_demo ${PROJECT_SOURCE_DIR}/infer_ppyoloe_demo.cc)
+target_link_libraries(infer_ppyoloe_demo ${FASTDEPLOY_LIBS})
diff --git a/examples/vision/detection/paddledetection/tvm/cpp/README.md b/examples/vision/detection/paddledetection/tvm/cpp/README.md
new file mode 100644
index 000000000..2fcba607b
--- /dev/null
+++ b/examples/vision/detection/paddledetection/tvm/cpp/README.md
@@ -0,0 +1,60 @@
+[English](README.md) | 简体中文
+
+# PaddleDetection C++ Deployment Example
+
+This directory provides `infer_ppyoloe_demo.cc`, a quick example of deploying a PaddleDetection model with TVM acceleration.
+
+## Convert the model and run
+
+```bash
+# build example
+mkdir build
+cd build
+cmake .. -DFASTDEPLOY_INSTALL_DIR=/path/to/fastdeploy-sdk
+make -j
+wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
+./infer_ppyoloe_demo ../tvm_save 000000014439.jpg
+```
+
+## PaddleDetection C++ API
+
+### Model classes
+
+PaddleDetection provides a set of detection model classes: `PPYOLOE`, `PicoDet`, `PaddleYOLOX`, `PPYOLO`, `FasterRCNN`, `SSD`, `PaddleYOLOv5`, `PaddleYOLOv6`, `PaddleYOLOv7`, `RTMDet`, `CascadeRCNN`, `PSSDet`, `RetinaNet`, `PPYOLOESOD`, `FCOS`, `TTFNet`, `TOOD` and `GFL`. All of them take exactly the same constructor and prediction parameters, so this document uses PPYOLOE as the example API.
+```c++
+fastdeploy::vision::detection::PPYOLOE(
+    const string& model_file,
+    const string& params_file,
+    const string& config_file,
+    const RuntimeOption& runtime_option = RuntimeOption(),
+    const ModelFormat& model_format = ModelFormat::PADDLE)
+```
+
+Loads and initializes a PaddleDetection PPYOLOE model, where `model_file` is the path of the exported model.
+
+**Parameters**
+
+> * **model_file**(str): path of the model file
+> * **params_file**(str): path of the parameters file
+> * **config_file**(str): path of the configuration file, i.e. the deployment yaml file exported by PaddleDetection
+> * **runtime_option**(RuntimeOption): backend inference configuration; if not set, the default configuration is used
+> * **model_format**(ModelFormat): model format, defaults to the Paddle format
+
+#### Predict function
+
+> ```c++
+> PPYOLOE::Predict(cv::Mat* im, DetectionResult* result)
+> ```
+>
+> Model prediction interface: takes an input image and directly returns the detection result.
+>
+> **Parameters**
+>
+> > * **im**: input image; note that it must be in HWC, BGR format
+> > * **result**: detection result, including the detection boxes and the confidence of each box; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for a description of DetectionResult
+
+- [Model Introduction](../../)
+- [Python Deployment](../python)
+- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
+- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)
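For completeness, the `DetectionResult` filled by `Predict` can be walked directly. The field names below (`boxes`, `scores`, `label_ids`) follow the vision-results documentation linked above; this is a sketch, not part of the patch, so verify the fields against your FastDeploy version.

```c++
#include <cstdio>

#include "fastdeploy/vision.h"

void PrintDetections(const fastdeploy::vision::DetectionResult& res,
                     float score_threshold = 0.5f) {
  for (size_t i = 0; i < res.boxes.size(); ++i) {
    if (res.scores[i] < score_threshold) continue;
    const auto& box = res.boxes[i];  // xmin, ymin, xmax, ymax
    std::printf("label=%d score=%.3f box=[%.1f, %.1f, %.1f, %.1f]\n",
                res.label_ids[i], res.scores[i], box[0], box[1], box[2], box[3]);
  }
}
```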
diff --git a/examples/vision/detection/paddledetection/tvm/cpp/infer_ppyoloe_demo.cc b/examples/vision/detection/paddledetection/tvm/cpp/infer_ppyoloe_demo.cc
new file mode 100644
index 000000000..de406b284
--- /dev/null
+++ b/examples/vision/detection/paddledetection/tvm/cpp/infer_ppyoloe_demo.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+void TVMInfer(const std::string& model_dir, const std::string& image_file) {
+  auto model_file = model_dir + "/tvm_model";
+  auto params_file = "";
+  auto config_file = model_dir + "/infer_cfg.yml";
+
+  auto option = fastdeploy::RuntimeOption();
+  option.UseCpu();
+  option.UseTVMBackend();
+
+  auto format = fastdeploy::ModelFormat::TVMFormat;
+
+  auto model = fastdeploy::vision::detection::PPYOLOE(
+      model_file, params_file, config_file, option, format);
+  model.GetPostprocessor().ApplyNMS();
+
+  auto im = cv::imread(image_file);
+
+  fastdeploy::vision::DetectionResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+
+  std::cout << res.Str() << std::endl;
+  auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
+  cv::imwrite("infer.jpg", vis_im);
+  std::cout << "Visualized result saved in ./infer.jpg" << std::endl;
+}
+
+int main(int argc, char* argv[]) {
+  if (argc < 3) {
+    std::cout
+        << "Usage: infer_demo path/to/model_dir path/to/image, "
+           "e.g. ./infer_ppyoloe_demo ./tvm_save ./test.jpeg"
+        << std::endl;
+    return -1;
+  }
+
+  TVMInfer(argv[1], argv[2]);
+  return 0;
+}
\ No newline at end of file
diff --git a/fastdeploy/core/config.h.in b/fastdeploy/core/config.h.in
index f5a8d41b7..efea5221c 100755
--- a/fastdeploy/core/config.h.in
+++ b/fastdeploy/core/config.h.in
@@ -71,4 +71,8 @@
 
 #ifndef ENABLE_HORIZON_BACKEND
 #cmakedefine ENABLE_HORIZON_BACKEND
-#endif
\ No newline at end of file
+#endif
+
+#ifndef ENABLE_TVM_BACKEND
+#cmakedefine ENABLE_TVM_BACKEND
+#endif
diff --git a/fastdeploy/runtime/backends/tvm/option.h b/fastdeploy/runtime/backends/tvm/option.h
new file mode 100644
index 000000000..45f998159
--- /dev/null
+++ b/fastdeploy/runtime/backends/tvm/option.h
@@ -0,0 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+namespace fastdeploy {
+struct TVMBackendOption {
+  TVMBackendOption() {}
+};
+
+} // namespace fastdeploy
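Before wiring everything through FastDeploy, the artifacts written by `tools/tvm/paddle2tvm.py` can be smoke-tested with the TVM C++ runtime alone. The sketch below (not part of this patch) mirrors the call sequence used by `TVMBackend` in the file that follows — `LoadFromFile` → `"default"` → `load_params` → `set_input`/`run`/`get_output`; the input names and shapes assume the PPYOLOE conversion from the example README.

```c++
#include <fstream>
#include <iterator>
#include <string>

#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

int main() {
  DLDevice dev{kDLCPU, 0};
  // Load the compiled library and build the graph executor module.
  tvm::runtime::Module mod_factory =
      tvm::runtime::Module::LoadFromFile("./tvm_save/tvm_model.so");
  tvm::runtime::Module gmod = mod_factory.GetFunction("default")(dev);

  // Load the serialized parameters exported by paddle2tvm.py.
  std::ifstream params_in("./tvm_save/tvm_model.params", std::ios::binary);
  std::string params_data((std::istreambuf_iterator<char>(params_in)),
                          std::istreambuf_iterator<char>());
  TVMByteArray params_arr{params_data.c_str(), params_data.size()};
  gmod.GetFunction("load_params")(params_arr);

  // Feed dummy float32 inputs matching the conversion shape_dict and run once.
  auto fp32 = DLDataType{kDLFloat, 32, 1};
  auto image = tvm::runtime::NDArray::Empty({1, 3, 640, 640}, fp32, dev);
  auto scale = tvm::runtime::NDArray::Empty({1, 2}, fp32, dev);
  gmod.GetFunction("set_input")("image", image);
  gmod.GetFunction("set_input")("scale_factor", scale);
  gmod.GetFunction("run")();
  tvm::runtime::NDArray out = gmod.GetFunction("get_output")(0);
  return out.defined() ? 0 : -1;
}
```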
diff --git a/fastdeploy/runtime/backends/tvm/tvm_backend.cc b/fastdeploy/runtime/backends/tvm/tvm_backend.cc
new file mode 100644
index 000000000..6e04f78b8
--- /dev/null
+++ b/fastdeploy/runtime/backends/tvm/tvm_backend.cc
@@ -0,0 +1,205 @@
+#include "fastdeploy/runtime/backends/tvm/tvm_backend.h"
+
+#include "yaml-cpp/yaml.h"
+
+namespace fastdeploy {
+bool TVMBackend::Init(const fastdeploy::RuntimeOption& runtime_option) {
+  if (!(Supported(runtime_option.model_format, Backend::TVM) &&
+        Supported(runtime_option.device, Backend::TVM))) {
+    FDERROR << "TVMBackend only supports ModelFormat::TVMFormat and the "
+               "devices allowed for Backend::TVM, but got "
+            << runtime_option.model_format << "/" << runtime_option.device
+            << "." << std::endl;
+    return false;
+  }
+
+  if (runtime_option.model_from_memory_) {
+    FDERROR << "TVMBackend doesn't support loading a model from memory, "
+               "please load the model from disk."
+            << std::endl;
+    return false;
+  }
+
+  if (!BuildDLDevice(runtime_option.device)) {
+    FDERROR << "TVMBackend doesn't support running on this device." << std::endl;
+    return false;
+  }
+
+  if (!BuildModel(runtime_option)) {
+    FDERROR << "TVMBackend failed to load the model from this path."
+            << std::endl;
+    return false;
+  }
+
+  if (!InitInputAndOutputTensor()) {
+    FDERROR << "InitInputAndOutputTensor failed." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+bool TVMBackend::InitInputAndOutputTensor() {
+  input_tensor_.resize(NumInputs());
+  for (int i = 0; i < NumInputs(); ++i) {
+    TensorInfo tensor_info = GetInputInfo(i);
+    tvm::ShapeTuple shape(tensor_info.shape.begin(), tensor_info.shape.end());
+    input_tensor_[i] = tvm::runtime::NDArray::Empty(
+        shape, FDDataTypeToDLDataType(tensor_info.dtype), dev_);
+  }
+
+  output_tensor_.resize(NumOutputs());
+  for (int i = 0; i < NumOutputs(); ++i) {
+    TensorInfo tensor_info = GetOutputInfo(i);
+    tvm::ShapeTuple shape(tensor_info.shape.begin(), tensor_info.shape.end());
+    output_tensor_[i] = tvm::runtime::NDArray::Empty(
+        shape, FDDataTypeToDLDataType(tensor_info.dtype), dev_);
+  }
+  return true;
+}
+
+bool TVMBackend::BuildModel(const RuntimeOption& runtime_option) {
+  // load in the library
+  tvm::runtime::Module mod_factory =
+      tvm::runtime::Module::LoadFromFile(runtime_option.model_file + ".so");
+
+  // create the graph executor module
+  gmod_ = mod_factory.GetFunction("default")(dev_);
+
+  // load params
+  std::ifstream params_in(runtime_option.model_file + ".params",
+                          std::ios::binary);
+  std::string params_data((std::istreambuf_iterator<char>(params_in)),
+                          std::istreambuf_iterator<char>());
+  params_in.close();
+  TVMByteArray params_arr;
+  params_arr.data = params_data.c_str();
+  params_arr.size = params_data.length();
+  tvm::runtime::PackedFunc load_params = gmod_.GetFunction("load_params");
+  load_params(params_arr);
+
+  // read input and output info
+  tvm::runtime::PackedFunc get_input_info = gmod_.GetFunction("get_input_info");
+  tvm::Map<tvm::String, tvm::ObjectRef> input_info = get_input_info();
+  auto input_info_shape = tvm::Downcast<tvm::Map<tvm::String, tvm::ShapeTuple>>(
+      input_info["shape"]);
+  inputs_desc_.reserve(input_info_shape.size());
+  for (auto map_node : input_info_shape) {
+    std::string temp_name = map_node.first;
+
+    tvm::ShapeTuple tup = map_node.second;
+    std::vector<int> temp_shape{};
+    temp_shape.resize(tup.size());
+    for (int j = 0; j < tup.size(); ++j) {
+      temp_shape[j] = static_cast<int>(tup[j]);
+    }
+
+    FDDataType temp_dtype = fastdeploy::UNKNOWN1;
+    TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
+    inputs_desc_.emplace_back(temp_input_info);
+  }
+
+  int input_dtype_index = 0;
+  auto input_info_dtype =
+      tvm::Downcast<tvm::Map<tvm::String, tvm::String>>(input_info["dtype"]);
+  for (auto map_node : input_info_dtype) {
+    tvm::String tup = map_node.second;
+    inputs_desc_[input_dtype_index].dtype = TVMTensorTypeToFDDataType(tup);
+    input_dtype_index++;
+  }
+
+  tvm::runtime::PackedFunc get_output_info =
+      gmod_.GetFunction("get_output_info");
+  tvm::Map<tvm::String, tvm::ObjectRef> output_info = get_output_info();
+  auto output_info_shape =
+      tvm::Downcast<tvm::Map<tvm::String, tvm::ShapeTuple>>(
+          output_info["shape"]);
+  outputs_desc_.reserve(output_info_shape.size());
+  for (auto map_node : output_info_shape) {
+    std::string temp_name = map_node.first;
+
+    tvm::ShapeTuple tup = map_node.second;
+    std::vector<int> temp_shape{};
+    temp_shape.resize(tup.size());
+    for (int j = 0; j < tup.size(); ++j) {
+      temp_shape[j] = static_cast<int>(tup[j]);
+    }
+
+    FDDataType temp_dtype = fastdeploy::FP32;
+    TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
+    outputs_desc_.emplace_back(temp_input_info);
+  }
+
+  int output_dtype_index = 0;
+  auto output_info_dtype =
+      tvm::Downcast<tvm::Map<tvm::String, tvm::String>>(output_info["dtype"]);
+  for (auto map_node : output_info_dtype) {
+    tvm::String tup = map_node.second;
+    outputs_desc_[output_dtype_index].dtype = TVMTensorTypeToFDDataType(tup);
+    output_dtype_index++;
+  }
+  return true;
+}
+
+FDDataType TVMBackend::TVMTensorTypeToFDDataType(tvm::String type) {
+  if (type == "float32") {
+    return FDDataType::FP32;
+  }
+  FDERROR << "Unsupported TVM tensor type, only float32 is supported now."
+          << std::endl;
+  return FDDataType::UNKNOWN1;
+}
+
+bool TVMBackend::Infer(std::vector<FDTensor>& inputs,
+                       std::vector<FDTensor>* outputs, bool copy_to_fd) {
+  for (int i = 0; i < inputs.size(); ++i) {
+    memcpy(input_tensor_[i]->data, inputs[i].Data(), inputs[i].Nbytes());
+  }
+
+  // get the function from the module(set input data)
+  tvm::runtime::PackedFunc set_input = gmod_.GetFunction("set_input");
+  for (int i = 0; i < NumInputs(); ++i) {
+    set_input(GetInputInfo(i).name, input_tensor_[i]);
+  }
+
+  // get the function from the module(run it)
+  tvm::runtime::PackedFunc run = gmod_.GetFunction("run");
+  run();
+
+  // get the function from the module(get output data)
+  tvm::runtime::PackedFunc get_output = gmod_.GetFunction("get_output");
+  for (int i = 0; i < NumOutputs(); ++i) {
+    get_output(i, output_tensor_[i]);
+  }
+
+  // get result
+  outputs->resize(NumOutputs());
+  std::vector<int64_t> temp_shape{};
+  for (size_t i = 0; i < outputs_desc_.size(); ++i) {
+    temp_shape.resize(outputs_desc_[i].shape.size());
+    for (int j = 0; j < outputs_desc_[i].shape.size(); ++j) {
+      temp_shape[j] = outputs_desc_[i].shape[j];
+    }
+    (*outputs)[i].Resize(temp_shape, outputs_desc_[i].dtype,
+                         outputs_desc_[i].name);
+    memcpy((*outputs)[i].MutableData(),
+           static_cast<void*>(output_tensor_[i]->data),
+           (*outputs)[i].Nbytes());
+  }
+  return true;
+}
+
+bool TVMBackend::BuildDLDevice(fastdeploy::Device device) {
+  if (device == Device::CPU) {
+    dev_ = DLDevice{kDLCPU, 0};
+  } else {
+    FDERROR << "TVMBackend only supports running on the CPU." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+DLDataType TVMBackend::FDDataTypeToDLDataType(fastdeploy::FDDataType dtype) {
+  if (dtype == FDDataType::FP32) {
+    return DLDataType{kDLFloat, 32, 1};
+  }
+  return {};
+}
+} // namespace fastdeploy
diff --git a/fastdeploy/runtime/backends/tvm/tvm_backend.h b/fastdeploy/runtime/backends/tvm/tvm_backend.h
new file mode 100644
index 000000000..a40b964df
--- /dev/null
+++ b/fastdeploy/runtime/backends/tvm/tvm_backend.h
@@ -0,0 +1,61 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "fastdeploy/core/fd_tensor.h"
+#include "fastdeploy/runtime/backends/backend.h"
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace fastdeploy {
+class TVMBackend : public BaseBackend {
+ public:
+  TVMBackend() = default;
+  virtual ~TVMBackend() = default;
+  bool Init(const RuntimeOption& runtime_option) override;
+  int NumInputs() const override { return inputs_desc_.size(); }
+  int NumOutputs() const override { return outputs_desc_.size(); }
+  TensorInfo GetInputInfo(int index) override { return inputs_desc_[index]; }
+  TensorInfo GetOutputInfo(int index) override { return outputs_desc_[index]; }
+  std::vector<TensorInfo> GetInputInfos() override { return inputs_desc_; }
+  std::vector<TensorInfo> GetOutputInfos() override { return outputs_desc_; }
+  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
+             bool copy_to_fd = true) override;
+
+ private:
+  DLDevice dev_{};
+  tvm::runtime::Module gmod_;
+  std::vector<TensorInfo> inputs_desc_;
+  std::vector<TensorInfo> outputs_desc_;
+
+  bool BuildDLDevice(Device device);
+  bool BuildModel(const RuntimeOption& runtime_option);
+  bool InitInputAndOutputTensor();
+
+  std::vector<tvm::runtime::NDArray> input_tensor_;
+  std::vector<tvm::runtime::NDArray> output_tensor_;
+
+  FDDataType TVMTensorTypeToFDDataType(tvm::String type);
+  DLDataType FDDataTypeToDLDataType(FDDataType dtype);
+};
+} // namespace fastdeploy
diff --git a/fastdeploy/runtime/enum_variables.cc b/fastdeploy/runtime/enum_variables.cc
index 61869740c..40e998946 100644
--- a/fastdeploy/runtime/enum_variables.cc
+++ b/fastdeploy/runtime/enum_variables.cc
@@ -32,8 +32,10 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) {
     out << "Backend::POROS";
   } else if (backend == Backend::LITE) {
     out << "Backend::PDLITE";
-  } else if(backend == Backend::HORIZONNPU){
+  } else if (backend == Backend::HORIZONNPU) {
     out << "Backend::HORIZONNPU";
+  } else if (backend == Backend::TVM) {
+    out << "Backend::TVM";
   } else {
     out << "UNKNOWN-Backend";
   }
@@ -88,8 +90,9 @@ std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
     out << "ModelFormat::TORCHSCRIPT";
   } else if (format == ModelFormat::HORIZON) {
     out << "ModelFormat::HORIZON";
-  }
-  else {
+  } else if (format == ModelFormat::TVMFormat) {
+    out << "ModelFormat::TVMFormat";
+  } else {
     out << "UNKNOWN-ModelFormat";
   }
   return out;
@@ -123,6 +126,9 @@ std::vector<Backend> GetAvailableBackends() {
 #endif
 #ifdef ENABLE_SOPHGO_BACKEND
   backends.push_back(Backend::SOPHGOTPU);
+#endif
+#ifdef ENABLE_TVM_BACKEND
+  backends.push_back(Backend::TVM);
 #endif
   return backends;
 }
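A quick way to confirm at run time that the SDK you link against was actually built with `ENABLE_TVM_BACKEND=ON` is to look for `Backend::TVM` in `GetAvailableBackends()`, which the hunk above extends. The small sketch below is not part of this patch; the declaration is assumed to be reachable through the runtime headers.

```c++
#include <algorithm>
#include <iostream>

#include "fastdeploy/runtime/runtime.h"

int main() {
  auto backends = fastdeploy::GetAvailableBackends();
  bool has_tvm = std::find(backends.begin(), backends.end(),
                           fastdeploy::Backend::TVM) != backends.end();
  std::cout << "TVM backend available: " << std::boolalpha << has_tvm
            << std::endl;
  return has_tvm ? 0 : 1;
}
```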
diff --git a/fastdeploy/runtime/enum_variables.h b/fastdeploy/runtime/enum_variables.h
index 0e23f21ee..7a13f4b76 100644
--- a/fastdeploy/runtime/enum_variables.h
+++ b/fastdeploy/runtime/enum_variables.h
@@ -39,6 +39,7 @@ enum Backend {
   RKNPU2,  ///< RKNPU2, support RKNN format model, Rockchip NPU only
   SOPHGOTPU,  ///< SOPHGOTPU, support SOPHGO format model, Sophgo TPU only
   HORIZONNPU,  ///< HORIZONNPU, support Horizon format model, Horizon NPU
+  TVM,  ///< TVMBackend, support TVM format model, CPU / Nvidia GPU
 };
 
 /**
@@ -74,6 +75,7 @@ enum ModelFormat {
   TORCHSCRIPT,  ///< Model with TorchScript format
   SOPHGO,  ///< Model with SOPHGO format
   HORIZON,  ///< Model with HORIZON format
+  TVMFormat,  ///< Model with TVM format
 };
 
 /// Describle all the supported backends for specified model format
@@ -85,16 +87,17 @@ static std::map<ModelFormat, std::vector<Backend>>
     {ModelFormat::RKNN, {Backend::RKNPU2}},
     {ModelFormat::HORIZON, {Backend::HORIZONNPU}},
     {ModelFormat::TORCHSCRIPT, {Backend::POROS}},
-    {ModelFormat::SOPHGO, {Backend::SOPHGOTPU}}
+    {ModelFormat::SOPHGO, {Backend::SOPHGOTPU}},
+    {ModelFormat::TVMFormat, {Backend::TVM}}
 };
 
 /// Describle all the supported backends for specified device
 static std::map<Device, std::vector<Backend>> s_default_backends_by_device = {
     {Device::CPU, {Backend::LITE, Backend::PDINFER, Backend::ORT,
-                   Backend::OPENVINO, Backend::POROS}},
+                   Backend::OPENVINO, Backend::POROS, Backend::TVM}},
     {Device::GPU, {Backend::LITE, Backend::PDINFER, Backend::ORT,
-                   Backend::TRT, Backend::POROS}},
+                   Backend::TRT, Backend::POROS, Backend::TVM}},
     {Device::RKNPU, {Backend::RKNPU2}},
     {Device::SUNRISENPU, {Backend::HORIZONNPU}},
     {Device::IPU, {Backend::PDINFER}},
diff --git a/fastdeploy/runtime/runtime.cc b/fastdeploy/runtime/runtime.cc
index 785f6f1bc..72004d0da 100644
--- a/fastdeploy/runtime/runtime.cc
+++ b/fastdeploy/runtime/runtime.cc
@@ -53,6 +53,10 @@
 #include "fastdeploy/runtime/backends/horizon/horizon_backend.h"
 #endif
 
+#ifdef ENABLE_TVM_BACKEND
+#include "fastdeploy/runtime/backends/tvm/tvm_backend.h"
+#endif
+
 namespace fastdeploy {
 
 bool AutoSelectBackend(RuntimeOption& option) {
@@ -159,10 +163,11 @@ bool Runtime::Init(const RuntimeOption& _option) {
     CreateSophgoNPUBackend();
   } else if (option.backend == Backend::POROS) {
     CreatePorosBackend();
-  } else if (option.backend == Backend::HORIZONNPU){
+  } else if (option.backend == Backend::HORIZONNPU) {
     CreateHorizonBackend();
-  }
-  else {
+  } else if (option.backend == Backend::TVM) {
+    CreateTVMBackend();
+  } else {
     std::string msg = Str(GetAvailableBackends());
     FDERROR << "The compiled FastDeploy only supports " << msg << ", "
             << option.backend << " is not supported now." << std::endl;
@@ -287,6 +292,19 @@ void Runtime::CreateOpenVINOBackend() {
          << "." << std::endl;
 }
 
+void Runtime::CreateTVMBackend() {
+#ifdef ENABLE_TVM_BACKEND
+  backend_ = utils::make_unique<TVMBackend>();
+  FDASSERT(backend_->Init(option), "Failed to initialize TVM backend.");
+#else
+  FDASSERT(false,
+           "TVMBackend is not available, please compile with "
+           "ENABLE_TVM_BACKEND=ON.");
+#endif
+  FDINFO << "Runtime initialized with Backend::TVM in " << option.device << "."
+         << std::endl;
+}
+
 void Runtime::CreateOrtBackend() {
 #ifdef ENABLE_ORT_BACKEND
   backend_ = utils::make_unique<OrtBackend>();
@@ -342,15 +360,14 @@ void Runtime::CreateRKNPU2Backend() {
          << "." << std::endl;
 }
 
-void Runtime::CreateHorizonBackend(){
+void Runtime::CreateHorizonBackend() {
 #ifdef ENABLE_HORIZON_BACKEND
   backend_ = utils::make_unique<HorizonBackend>();
   FDASSERT(backend_->Init(option), "Failed to initialize Horizon backend.");
 #else
-  FDASSERT(false,
-           "HorizonBackend is not available, please compiled with ",
+  FDASSERT(false, "HorizonBackend is not available, please compiled with ",
            " ENABLE_HORIZON_BACKEND=ON.");
-#endif
+#endif
   FDINFO << "Runtime initialized with Backend::HORIZONNPU in " << option.device << "."
<< std::endl; } diff --git a/fastdeploy/runtime/runtime.h b/fastdeploy/runtime/runtime.h index d4b5844ba..a1adf53fa 100755 --- a/fastdeploy/runtime/runtime.h +++ b/fastdeploy/runtime/runtime.h @@ -118,6 +118,7 @@ struct FASTDEPLOY_DECL Runtime { void CreateHorizonBackend(); void CreateSophgoNPUBackend(); void CreatePorosBackend(); + void CreateTVMBackend(); std::unique_ptr backend_; std::vector input_tensors_; std::vector output_tensors_; diff --git a/fastdeploy/runtime/runtime_option.cc b/fastdeploy/runtime/runtime_option.cc index 2af84d482..a2a232ced 100644 --- a/fastdeploy/runtime/runtime_option.cc +++ b/fastdeploy/runtime/runtime_option.cc @@ -49,11 +49,11 @@ void RuntimeOption::UseGpu(int gpu_id) { #if defined(WITH_GPU) || defined(WITH_OPENCL) device = Device::GPU; device_id = gpu_id; - + #if defined(WITH_OPENCL) && defined(ENABLE_LITE_BACKEND) paddle_lite_option.device = device; #endif - + #else FDWARNING << "The FastDeploy didn't compile with GPU, will force to use CPU." << std::endl; @@ -70,9 +70,7 @@ void RuntimeOption::UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name, device = Device::RKNPU; } -void RuntimeOption::UseHorizon(){ - device = Device::SUNRISENPU; -} +void RuntimeOption::UseHorizon() { device = Device::SUNRISENPU; } void RuntimeOption::UseTimVX() { device = Device::TIMVX; @@ -84,8 +82,7 @@ void RuntimeOption::UseKunlunXin(int kunlunxin_id, bool locked, bool autotune, const std::string& autotune_file, const std::string& precision, - bool adaptive_seqlen, - bool enable_multi_stream, + bool adaptive_seqlen, bool enable_multi_stream, int64_t gm_default_size) { #ifdef WITH_KUNLUNXIN device = Device::KUNLUNXIN; @@ -236,7 +233,7 @@ void RuntimeOption::UseLiteBackend() { #endif } -void RuntimeOption::UseHorizonNPUBackend(){ +void RuntimeOption::UseHorizonNPUBackend() { #ifdef ENABLE_HORIZON_BACKEND backend = Backend::HORIZONNPU; #else @@ -524,4 +521,12 @@ void RuntimeOption::DisablePaddleTrtOPs(const std::vector& ops) { paddle_infer_option.DisableTrtOps(ops); } +void RuntimeOption::UseTVMBackend() { +#ifdef ENABLE_TVM_BACKEND + backend = Backend::TVM; +#else + FDASSERT(false, "The FastDeploy didn't compile with TVMBackend."); +#endif +} + } // namespace fastdeploy diff --git a/fastdeploy/runtime/runtime_option.h b/fastdeploy/runtime/runtime_option.h index 38fda025b..205a2184c 100755 --- a/fastdeploy/runtime/runtime_option.h +++ b/fastdeploy/runtime/runtime_option.h @@ -31,6 +31,7 @@ #include "fastdeploy/runtime/backends/rknpu2/option.h" #include "fastdeploy/runtime/backends/sophgo/option.h" #include "fastdeploy/runtime/backends/tensorrt/option.h" +#include "fastdeploy/runtime/backends/tvm/option.h" #include "fastdeploy/benchmark/option.h" namespace fastdeploy { @@ -160,6 +161,8 @@ struct FASTDEPLOY_DECL RuntimeOption { LiteBackendOption paddle_lite_option; /// Option to configure RKNPU2 backend RKNPU2BackendOption rknpu2_option; + /// Option to configure TVM backend + TVMBackendOption tvm_option; // \brief Set the profile mode as 'true'. 
 //
@@ -282,6 +285,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
   void UsePaddleBackend();
   void UseLiteBackend();
   void UseHorizonNPUBackend();
+  void UseTVMBackend();
 };
 
 }  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/ppdet/model.h b/fastdeploy/vision/detection/ppdet/model.h
index ec6b9fbfd..8d242a156 100755
--- a/fastdeploy/vision/detection/ppdet/model.h
+++ b/fastdeploy/vision/detection/ppdet/model.h
@@ -62,12 +62,12 @@ class FASTDEPLOY_DECL SOLOv2 : public PPDetBase {
    * \param[in] model_format Model format of the loaded model, default is Paddle format
    */
   SOLOv2(const std::string& model_file, const std::string& params_file,
-       const std::string& config_file,
-       const RuntimeOption& custom_option = RuntimeOption(),
-       const ModelFormat& model_format = ModelFormat::PADDLE)
+         const std::string& config_file,
+         const RuntimeOption& custom_option = RuntimeOption(),
+         const ModelFormat& model_format = ModelFormat::PADDLE)
       : PPDetBase(model_file, params_file, config_file, custom_option,
                   model_format) {
-    valid_cpu_backends = { Backend::PDINFER};
+    valid_cpu_backends = {Backend::PDINFER};
     valid_gpu_backends = {Backend::PDINFER, Backend::TRT};
     initialized = Initialize();
   }
@@ -92,7 +92,7 @@ class FASTDEPLOY_DECL PPYOLOE : public PPDetBase {
       : PPDetBase(model_file, params_file, config_file, custom_option,
                   model_format) {
     valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
-                          Backend::LITE};
+                          Backend::LITE, Backend::TVM};
     valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
     valid_timvx_backends = {Backend::LITE};
     valid_kunlunxin_backends = {Backend::LITE};
@@ -468,9 +468,9 @@ class FASTDEPLOY_DECL PaddleDetectionModel : public PPDetBase {
 class FASTDEPLOY_DECL PPYOLOER : public PPDetBase {
  public:
   PPYOLOER(const std::string& model_file, const std::string& params_file,
-        const std::string& config_file,
-        const RuntimeOption& custom_option = RuntimeOption(),
-        const ModelFormat& model_format = ModelFormat::PADDLE)
+           const std::string& config_file,
+           const RuntimeOption& custom_option = RuntimeOption(),
+           const ModelFormat& model_format = ModelFormat::PADDLE)
       : PPDetBase(model_file, params_file, config_file, custom_option,
                   model_format) {
     valid_cpu_backends = {Backend::PDINFER, Backend::OPENVINO, Backend::ORT,
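The `model.h` change above only adds `Backend::TVM` to PPYOLOE's `valid_cpu_backends`; selecting it still goes through `RuntimeOption`, and because the compiled graph was pruned in front of the NMS node (see the conversion README), NMS has to be re-enabled in the postprocessor. A condensed sketch of the resulting call pattern, matching `infer_ppyoloe_demo.cc`:

```c++
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseCpu();
  option.UseTVMBackend();
  auto model = fastdeploy::vision::detection::PPYOLOE(
      "./tvm_save/tvm_model", "", "./tvm_save/infer_cfg.yml", option,
      fastdeploy::ModelFormat::TVMFormat);
  // NMS was pruned out of the compiled graph, so run it in the postprocessor.
  model.GetPostprocessor().ApplyNMS();
  return model.Initialized() ? 0 : -1;
}
```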
diff --git a/tools/tvm/paddle2tvm.py b/tools/tvm/paddle2tvm.py
new file mode 100644
index 000000000..c8c82ccb9
--- /dev/null
+++ b/tools/tvm/paddle2tvm.py
@@ -0,0 +1,58 @@
+import paddle
+import tvm
+from tvm import relay
+from tvm.contrib import graph_executor
+import os
+import argparse
+
+
+def get_config():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_path", default="./picodet_l_320_coco_lcnet/model")
+    parser.add_argument(
+        "--shape_dict",
+        default={"image": [1, 3, 320, 320],
+                 "scale_factor": [1, 2]})
+    parser.add_argument("--tvm_save_name", default="tvm_model")
+    parser.add_argument("--tvm_save_path", default="./tvm_save")
+    args = parser.parse_args()
+    return args
+
+
+def read_model(model_path):
+    return paddle.jit.load(model_path)
+
+
+def paddle_to_tvm(paddle_model,
+                  shape_dict,
+                  tvm_save_name="tvm_model",
+                  tvm_save_path="./tvm_save"):
+    if isinstance(shape_dict, str):
+        shape_dict = eval(shape_dict)
+    mod, params = relay.frontend.from_paddle(paddle_model, shape_dict)
+    # Test on the PC's CPU first, so export with the LLVM target
+    target = tvm.target.Target("llvm", host="llvm")
+    dev = tvm.cpu(0)
+    # Build the optimized model with TVM
+    with tvm.transform.PassContext(opt_level=2):
+        base_lib = relay.build_module.build(mod, target, params=params)
+    if not os.path.exists(tvm_save_path):
+        os.mkdir(tvm_save_path)
+    lib_save_path = os.path.join(tvm_save_path, tvm_save_name + ".so")
+    base_lib.export_library(lib_save_path)
+    param_save_path = os.path.join(tvm_save_path,
+                                   tvm_save_name + ".params")
+    with open(param_save_path, 'wb') as fo:
+        fo.write(relay.save_param_dict(base_lib.get_params()))
+    module = graph_executor.GraphModule(base_lib['default'](dev))
+    module.load_params(relay.save_param_dict(base_lib.get_params()))
+    print("export success")
+
+
+if __name__ == "__main__":
+    config = get_config()
+    paddle_model = read_model(config.model_path)
+    shape_dict = config.shape_dict
+    paddle_to_tvm(paddle_model, shape_dict, config.tvm_save_name,
+                  config.tvm_save_path)