[Hackathon 181] Add TVM support for FastDeploy on macOS (#1969)

* update for tvm backend

* update third_party

* update third_party

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
Author: Zheng-Bicheng
Date: 2023-05-25 19:59:02 +08:00
Committed by: GitHub
Commit: 643730bf5f (parent 49c033a828)
20 changed files with 658 additions and 31 deletions


@@ -63,6 +63,7 @@ option(ENABLE_POROS_BACKEND "Whether to enable poros backend." OFF)
option(ENABLE_OPENVINO_BACKEND "Whether to enable openvino backend." OFF)
option(ENABLE_RKNPU2_BACKEND "Whether to enable RKNPU2 backend." OFF)
option(ENABLE_SOPHGO_BACKEND "Whether to enable SOPHON backend." OFF)
option(ENABLE_TVM_BACKEND "Whether to enable TVM backend." OFF)
option(ENABLE_LITE_BACKEND "Whether to enable paddle lite backend." OFF)
option(ENABLE_HORIZON_BACKEND "Whether to enable HORIZON backend." OFF)
option(ENABLE_VISION "Whether to enable vision models usage." OFF)
@@ -169,6 +170,7 @@ file(GLOB_RECURSE DEPLOY_OPENVINO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/f
file(GLOB_RECURSE DEPLOY_RKNPU2_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/rknpu2/*.cc)
file(GLOB_RECURSE DEPLOY_HORIZON_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/horizon/*.cc)
file(GLOB_RECURSE DEPLOY_SOPHGO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/sophgo/*.cc)
file(GLOB_RECURSE DEPLOY_TVM_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/tvm/*.cc)
file(GLOB_RECURSE DEPLOY_LITE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/lite/*.cc)
file(GLOB_RECURSE DEPLOY_ENCRYPTION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/encryption/*.cc)
file(GLOB_RECURSE DEPLOY_PIPELINE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pipeline/*.cc)
@@ -188,7 +190,8 @@ list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS}
${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS}
${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS}
${DEPLOY_PIPELINE_SRCS} ${DEPLOY_RKNPU2_SRCS}
${DEPLOY_SOPHGO_SRCS} ${DEPLOY_ENCRYPTION_SRCS} ${DEPLOY_HORIZON_SRCS})
${DEPLOY_SOPHGO_SRCS} ${DEPLOY_ENCRYPTION_SRCS}
${DEPLOY_HORIZON_SRCS} ${DEPLOY_TVM_SRCS})
set(DEPEND_LIBS "")
@@ -263,6 +266,14 @@ if(ENABLE_HORIZON_BACKEND)
list(APPEND DEPEND_LIBS ${BPU_libs})
endif()
if(ENABLE_TVM_BACKEND)
set(CMAKE_CXX_STANDARD 17)
add_definitions(-DENABLE_TVM_BACKEND)
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_TVM_SRCS})
include(${PROJECT_SOURCE_DIR}/cmake/tvm.cmake)
list(APPEND DEPEND_LIBS ${TVM_RUNTIME_LIB})
endif()
if(ENABLE_SOPHGO_BACKEND)
add_definitions(-DENABLE_SOPHGO_BACKEND)
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_SOPHGO_SRCS})
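Taken together, these CMakeLists.txt changes let the new backend be switched on at configure time. Below is a minimal sketch of such a build; the clone path, job count, and extra flags such as ENABLE_VISION are illustrative and should be adapted to your environment:
```bash
# Sketch: build FastDeploy with the TVM backend enabled (flags illustrative).
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy && mkdir build && cd build
cmake .. \
  -DENABLE_TVM_BACKEND=ON \
  -DENABLE_VISION=ON \
  -DCMAKE_INSTALL_PREFIX=$(pwd)/installed_fastdeploy
make -j8 && make install
```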


@@ -24,6 +24,7 @@ set(RKNN2_TARGET_SOC "@RKNN2_TARGET_SOC@")
# Inference backend and FastDeploy Module
set(ENABLE_ORT_BACKEND @ENABLE_ORT_BACKEND@)
set(ENABLE_RKNPU2_BACKEND @ENABLE_RKNPU2_BACKEND@)
set(ENABLE_TVM_BACKEND @ENABLE_TVM_BACKEND@)
set(ENABLE_HORIZON_BACKEND @ENABLE_HORIZON_BACKEND@)
set(ENABLE_SOPHGO_BACKEND @ENABLE_SOPHGO_BACKEND@)
set(ENABLE_LITE_BACKEND @ENABLE_LITE_BACKEND@)
@@ -129,6 +130,15 @@ if(ENABLE_ORT_BACKEND)
list(APPEND FASTDEPLOY_LIBS ${ORT_LIB})
endif()
if(ENABLE_TVM_BACKEND)
if(APPLE)
set(TVM_RUNTIME_LIB ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/tvm/lib/libtvm_runtime.dylib)
else()
set(TVM_RUNTIME_LIB ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/tvm/lib/libtvm_runtime.so)
endif()
list(APPEND FASTDEPLOY_LIBS ${TVM_RUNTIME_LIB})
endif()
if(ENABLE_PADDLE_BACKEND)
find_library(PADDLE_LIB paddle_inference ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/paddle/lib NO_DEFAULT_PATH)
if(WIN32)


@@ -40,6 +40,7 @@ function(fastdeploy_summary)
message(STATUS " ENABLE_POROS_BACKEND : ${ENABLE_POROS_BACKEND}")
message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}")
message(STATUS " ENABLE_TVM_BACKEND : ${ENABLE_TVM_BACKEND}")
message(STATUS " ENABLE_BENCHMARK : ${ENABLE_BENCHMARK}")
message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")

cmake/tvm.cmake (new file, 55 lines)

@@ -0,0 +1,55 @@
# set path
set(TVM_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
set(TVM_VERSION "0.12.0")
set(TVM_SYSTEM "")
if (${CMAKE_SYSTEM} MATCHES "Darwin")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "arm64")
set(TVM_SYSTEM "macos-arm64")
endif ()
elseif (${CMAKE_SYSTEM} MATCHES "Linux")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86")
set(TVM_SYSTEM "linux-x86")
endif ()
else ()
message(FATAL_ERROR "The TVM backend only supports macOS on arm64 or Linux on x86.")
endif ()
set(TVM_FILE "tvm-${TVM_SYSTEM}-${TVM_VERSION}.tgz")
set(TVM_URL "${TVM_URL_BASE}${TVM_FILE}")
set(TVM_RUNTIME_PATH "${THIRD_PARTY_PATH}/install/tvm")
execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${TVM_RUNTIME_PATH}")
download_and_decompress(${TVM_URL}
"${CMAKE_CURRENT_BINARY_DIR}/${TVM_FILE}"
"${THIRD_PARTY_PATH}/install/")
include_directories(${TVM_RUNTIME_PATH}/include)
# copy dlpack to third_party
set(DLPACK_PATH "${THIRD_PARTY_PATH}/install/dlpack")
execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${DLPACK_PATH}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_directory
"${PROJECT_SOURCE_DIR}/third_party/dlpack"
"${THIRD_PARTY_PATH}/install/")
include_directories(${DLPACK_PATH}/include)
set(DMLC_CORE_PATH "${THIRD_PARTY_PATH}/install/dmlc-core")
execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${DMLC_CORE_PATH}")
set(DMLC_CORE_URL https://bj.bcebos.com/fastdeploy/third_libs/dmlc-core.tgz)
download_and_decompress(${DMLC_CORE_URL}
"${CMAKE_CURRENT_BINARY_DIR}/dmlc-core.tgz"
"${THIRD_PARTY_PATH}/install/")
include_directories(${DMLC_CORE_PATH}/include)
# include lib
if (EXISTS ${TVM_RUNTIME_PATH})
if (${CMAKE_SYSTEM} MATCHES "Darwin")
set(TVM_RUNTIME_LIB ${TVM_RUNTIME_PATH}/lib/libtvm_runtime.dylib)
elseif (${CMAKE_SYSTEM} MATCHES "Linux")
set(TVM_RUNTIME_LIB ${TVM_RUNTIME_PATH}/lib/libtvm_runtime.so)
endif ()
include(${TVM_RUNTIME_PATH}/lib/cmake/tvm/tvmConfig.cmake)
add_definitions(-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
else ()
message(FATAL_ERROR "[tvm.cmake] TVM_RUNTIME_PATH does not exist.")
endif ()
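For reference, tvm.cmake assumes the downloaded archives unpack to roughly the following layout under ${THIRD_PARTY_PATH}/install (inferred from the paths it references; the actual archive contents may differ):
```
tvm/
├── include/                      # TVM runtime headers
└── lib/
    ├── libtvm_runtime.dylib      # libtvm_runtime.so on Linux
    └── cmake/tvm/tvmConfig.cmake
dlpack/include/                   # dlpack headers copied from third_party
dmlc-core/include/                # dmlc-core headers from dmlc-core.tgz
```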


@@ -0,0 +1,35 @@
[English](README.md) | Simplified Chinese
# PaddleDetection TVM Deployment Example
The PaddleDetection models that have been tested on TVM are:
* picodet
* PPYOLOE
### Converting a Paddle model to a TVM model
Since TVM does not support the NMS operator, the PaddleDetection model has to be pruned before conversion, changing the model's output nodes to the input nodes of the NMS node.
Run the following commands to obtain a pruned PPYOLOE model.
```bash
git clone https://github.com/PaddlePaddle/Paddle2ONNX.git
cd Paddle2ONNX/tools/paddle
wget https://bj.bcebos.com/fastdeploy/models/ppyoloe_plus_crn_m_80e_coco.tgz
tar xvf ppyoloe_plus_crn_m_80e_coco.tgz
python prune_paddle_model.py --model_dir ppyoloe_plus_crn_m_80e_coco \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--output_names tmp_17 concat_14.tmp_0 \
--save_dir ppyoloe_plus_crn_m_80e_coco
```
After pruning the model, it can be compiled with the TVM Python library; a conversion script is provided here for convenience.
Run the following commands to obtain the converted TVM model.
Note that when running PPYOLOE inference, FastDeploy depends not only on the model but also on its yml file, so you also need to copy the corresponding yml file into the model directory.
```bash
python path/to/FastDeploy/tools/tvm/paddle2tvm.py --model_path=./ppyoloe_plus_crn_m_80e_coco/model \
--shape_dict="{'image': [1, 3, 640, 640], 'scale_factor': [1, 2]}"
cp ppyoloe_plus_crn_m_80e_coco/infer_cfg.yml tvm_save
```
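If the conversion succeeds, the `tvm_save` directory should contain roughly the following files (names follow the script's defaults; this is a sketch, not a guaranteed listing):
```bash
ls tvm_save
# tvm_model.so       compiled TVM library
# tvm_model.params   serialized model parameters
# infer_cfg.yml      PaddleDetection deployment config copied in the step above
```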


@@ -0,0 +1,13 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
# Specify the path of the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add FastDeploy header dependencies
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_ppyoloe_demo ${PROJECT_SOURCE_DIR}/infer_ppyoloe_demo.cc)
target_link_libraries(infer_ppyoloe_demo ${FASTDEPLOY_LIBS})


@@ -0,0 +1,60 @@
[English](README.md) | Simplified Chinese
# PaddleDetection C++ Deployment Example
This directory provides `infer_ppyoloe_demo.cc`, a quick example of deploying PaddleDetection models with TVM acceleration.
## Convert the model and run
```bash
# build example
mkdir build
cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=/path/to/fastdeploy-sdk
make -j
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
./infer_ppyoloe_demo ../tvm_save 000000014439.jpg
```
## PaddleDetection C++ API
### Model classes
PaddleDetection currently supports multiple model series; the class names are `PPYOLOE`, `PicoDet`, `PaddleYOLOX`, `PPYOLO`, `FasterRCNN`, `SSD`, `PaddleYOLOv5`, `PaddleYOLOv6`, `PaddleYOLOv7`, `RTMDet`, `CascadeRCNN`, `PSSDet`, `RetinaNet`, `PPYOLOESOD`, `FCOS`, `TTFNet`, `TOOD`, and `GFL`. The constructors and prediction functions of all of these classes take exactly the same parameters; this document uses PPYOLOE as the example API.
```c++
fastdeploy::vision::detection::PPYOLOE(
const string& model_file,
const string& params_file,
const string& config_file,
const RuntimeOption& runtime_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
```
Loads and initializes a PaddleDetection PPYOLOE model, where model_file is the path to the exported model.
**Parameters**
> * **model_file**(str): Path to the model file
> * **params_file**(str): Path to the parameters file
> * **config_file**(str): Path to the configuration file, i.e. the deployment yaml file exported by PaddleDetection
> * **runtime_option**(RuntimeOption): Backend inference configuration; defaults to None, which means the default configuration is used
> * **model_format**(ModelFormat): Model format; defaults to the PADDLE format
#### Predict function
> ```c++
> PPYOLOE::Predict(cv::Mat* im, DetectionResult* result)
> ```
>
> Model prediction interface: takes an input image and directly outputs the detection result.
>
> **Parameters**
>
> > * **im**: Input image; note that it must be in HWC, BGR format
> > * **result**: Detection result, including the detection boxes and the confidence of each box; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for a description of DetectionResult
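For example, a minimal TVM deployment of PPYOLOE looks like the sketch below (paths are illustrative, based on the demo in this directory). Because NMS was pruned from the exported model, it must be re-enabled in the postprocessor:
```c++
// Sketch: minimal TVM-based PPYOLOE inference; paths are illustrative.
#include "fastdeploy/vision.h"
#include <iostream>

int main() {
  fastdeploy::RuntimeOption option;
  option.UseCpu();
  option.UseTVMBackend();

  auto model = fastdeploy::vision::detection::PPYOLOE(
      "tvm_save/tvm_model", "", "tvm_save/infer_cfg.yml", option,
      fastdeploy::ModelFormat::TVMFormat);
  model.GetPostprocessor().ApplyNMS();  // NMS was pruned from the model itself

  cv::Mat im = cv::imread("000000014439.jpg");
  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) return -1;
  std::cout << res.Str() << std::endl;
  return 0;
}
```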
- [Model description](../../)
- [Python deployment](../python)
- [Vision model prediction results](../../../../../docs/api/vision_results/)
- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)


@@ -0,0 +1,57 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
void TVMInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + "/tvm_model";
auto params_file = "";
auto config_file = model_dir + "/infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseCpu();
option.UseTVMBackend();
auto format = fastdeploy::ModelFormat::TVMFormat;
auto model = fastdeploy::vision::detection::PPYOLOE(
model_file, params_file, config_file, option, format);
model.GetPostprocessor().ApplyNMS();
auto im = cv::imread(image_file);
fastdeploy::vision::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
cv::imwrite("infer.jpg", vis_im);
std::cout << "Visualized result saved in ./infer.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 3) {
std::cout
<< "Usage: infer_demo path/to/model_dir path/to/image run_option, "
"e.g ./infer_model ./picodet_model_dir ./test.jpeg"
<< std::endl;
return -1;
}
TVMInfer(argv[1], argv[2]);
return 0;
}


@@ -72,3 +72,7 @@
#ifndef ENABLE_HORIZON_BACKEND
#cmakedefine ENABLE_HORIZON_BACKEND
#endif
#ifndef ENABLE_TVM_BACKEND
#cmakedefine ENABLE_TVM_BACKEND
#endif


@@ -0,0 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace fastdeploy {
struct TVMBackendOption {
TVMBackendOption() {}
};
} // namespace fastdeploy


@@ -0,0 +1,205 @@
#include "fastdeploy/runtime/backends/tvm/tvm_backend.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
bool TVMBackend::Init(const fastdeploy::RuntimeOption& runtime_option) {
if (!(Supported(runtime_option.model_format, Backend::TVM) &&
Supported(runtime_option.device, Backend::TVM))) {
FDERROR << "TVMBackend only supports model "
"ModelFormat::TVMFormat/Backend::TVM, but now its "
<< runtime_option.model_format << "/" << runtime_option.device
<< std::endl;
return false;
}
if (runtime_option.model_from_memory_) {
FDERROR << "TVMBackend doesn't support load model from memory, please "
"load model from disk."
<< std::endl;
return false;
}
if (!BuildDLDevice(runtime_option.device)) {
FDERROR << "TVMBackend only don't support run in this device." << std::endl;
return false;
}
if (!BuildModel(runtime_option)) {
FDERROR << "TVMBackend only don't support run with this model path."
<< std::endl;
return false;
}
if (!InitInputAndOutputTensor()) {
FDERROR << "InitInputAndOutputTensor failed." << std::endl;
return false;
}
return true;
}
bool TVMBackend::InitInputAndOutputTensor() {
input_tensor_.resize(NumInputs());
for (int i = 0; i < NumInputs(); ++i) {
TensorInfo tensor_info = GetInputInfo(i);
tvm::ShapeTuple shape(tensor_info.shape.begin(), tensor_info.shape.end());
input_tensor_[i] = tvm::runtime::NDArray::Empty(
shape, FDDataTypeToDLDataType(tensor_info.dtype), dev_);
}
output_tensor_.resize(NumOutputs());
for (int i = 0; i < NumOutputs(); ++i) {
TensorInfo tensor_info = GetOutputInfo(i);
tvm::ShapeTuple shape(tensor_info.shape.begin(), tensor_info.shape.end());
output_tensor_[i] = tvm::runtime::NDArray::Empty(
shape, FDDataTypeToDLDataType(tensor_info.dtype), dev_);
}
return true;
}
bool TVMBackend::BuildModel(const RuntimeOption& runtime_option) {
// load in the library
tvm::runtime::Module mod_factory =
tvm::runtime::Module::LoadFromFile(runtime_option.model_file + ".so");
// create the graph executor module
gmod_ = mod_factory.GetFunction("default")(dev_);
// load params
std::ifstream params_in(runtime_option.model_file + ".params",
std::ios::binary);
std::string params_data((std::istreambuf_iterator<char>(params_in)),
std::istreambuf_iterator<char>());
params_in.close();
TVMByteArray params_arr;
params_arr.data = params_data.c_str();
params_arr.size = params_data.length();
tvm::runtime::PackedFunc load_params = gmod_.GetFunction("load_params");
load_params(params_arr);
// read input and output info
tvm::runtime::PackedFunc get_input_info = gmod_.GetFunction("get_input_info");
tvm::Map<tvm::String, tvm::ObjectRef> input_info = get_input_info();
auto input_info_shape = tvm::Downcast<tvm::Map<tvm::String, tvm::ShapeTuple>>(
input_info["shape"]);
inputs_desc_.reserve(input_info_shape.size());
for (auto map_node : input_info_shape) {
std::string temp_name = map_node.first;
tvm::ShapeTuple tup = map_node.second;
std::vector<int> temp_shape{};
temp_shape.resize(tup.size());
for (int j = 0; j < tup.size(); ++j) {
temp_shape[j] = static_cast<int>(tup[j]);
}
FDDataType temp_dtype = fastdeploy::UNKNOWN1;
TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
inputs_desc_.emplace_back(temp_input_info);
}
int input_dtype_index = 0;
auto input_info_dtype =
tvm::Downcast<tvm::Map<tvm::String, tvm::String>>(input_info["dtype"]);
for (auto map_node : input_info_dtype) {
tvm::String tup = map_node.second;
inputs_desc_[input_dtype_index].dtype = TVMTensorTypeToFDDataType(tup);
input_dtype_index++;
}
tvm::runtime::PackedFunc get_output_info =
gmod_.GetFunction("get_output_info");
tvm::Map<tvm::String, tvm::ObjectRef> output_info = get_output_info();
auto output_info_shape =
tvm::Downcast<tvm::Map<tvm::String, tvm::ShapeTuple>>(
output_info["shape"]);
outputs_desc_.reserve(output_info_shape.size());
for (auto map_node : output_info_shape) {
std::string temp_name = map_node.first;
tvm::ShapeTuple tup = map_node.second;
std::vector<int> temp_shape{};
temp_shape.resize(tup.size());
for (int j = 0; j < tup.size(); ++j) {
temp_shape[j] = static_cast<int>(tup[j]);
}
FDDataType temp_dtype = fastdeploy::FP32;
TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
outputs_desc_.emplace_back(temp_input_info);
}
int output_dtype_index = 0;
auto output_info_dtype =
tvm::Downcast<tvm::Map<tvm::String, tvm::String>>(output_info["dtype"]);
for (auto map_node : output_info_dtype) {
tvm::String tup = map_node.second;
outputs_desc_[output_dtype_index].dtype = TVMTensorTypeToFDDataType(tup);
output_dtype_index++;
}
return true;
}
FDDataType TVMBackend::TVMTensorTypeToFDDataType(tvm::String type) {
if (type == "float32") {
return FDDataType::FP32;
}
FDERROR << "FDDataType don't support this type" << std::endl;
return FDDataType::UNKNOWN1;
}
bool TVMBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs, bool copy_to_fd) {
for (int i = 0; i < inputs.size(); ++i) {
memcpy(input_tensor_[i]->data, inputs[i].Data(), inputs[i].Nbytes());
}
// get the function from the module(set input data)
tvm::runtime::PackedFunc set_input = gmod_.GetFunction("set_input");
for (int i = 0; i < NumInputs(); ++i) {
set_input(GetInputInfo(i).name, input_tensor_[i]);
}
// get the function from the module(run it)
tvm::runtime::PackedFunc run = gmod_.GetFunction("run");
run();
// get the function from the module(get output data)
tvm::runtime::PackedFunc get_output = gmod_.GetFunction("get_output");
for (int i = 0; i < NumOutputs(); ++i) {
get_output(i, output_tensor_[i]);
}
// get result
outputs->resize(NumOutputs());
std::vector<int64_t> temp_shape{};
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
temp_shape.resize(outputs_desc_[i].shape.size());
for (int j = 0; j < outputs_desc_[i].shape.size(); ++j) {
temp_shape[j] = outputs_desc_[i].shape[j];
}
(*outputs)[i].Resize(temp_shape, outputs_desc_[i].dtype,
outputs_desc_[i].name);
memcpy((*outputs)[i].MutableData(),
static_cast<float*>(output_tensor_[i]->data),
(*outputs)[i].Nbytes());
}
return true;
}
bool TVMBackend::BuildDLDevice(fastdeploy::Device device) {
if (device == Device::CPU) {
dev_ = DLDevice{kDLCPU, 0};
} else {
FDERROR << "TVMBackend only support run in CPU." << std::endl;
return false;
}
return true;
}
DLDataType TVMBackend::FDDataTypeToDLDataType(fastdeploy::FDDataType dtype) {
if (dtype == FDDataType::FP32) {
return DLDataType{kDLFloat, 32, 1};
}
return {};
}
} // namespace fastdeploy


@@ -0,0 +1,61 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/runtime/backends/backend.h"
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <unistd.h>
namespace fastdeploy {
class TVMBackend : public BaseBackend {
public:
TVMBackend() = default;
virtual ~TVMBackend() = default;
bool Init(const RuntimeOption& runtime_option) override;
int NumInputs() const override { return inputs_desc_.size(); }
int NumOutputs() const override { return outputs_desc_.size(); }
TensorInfo GetInputInfo(int index) override { return inputs_desc_[index]; }
TensorInfo GetOutputInfo(int index) override { return outputs_desc_[index]; }
std::vector<TensorInfo> GetInputInfos() override { return inputs_desc_; }
std::vector<TensorInfo> GetOutputInfos() override { return outputs_desc_; }
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
bool copy_to_fd = true) override;
private:
DLDevice dev_{};
tvm::runtime::Module gmod_;
std::vector<TensorInfo> inputs_desc_;
std::vector<TensorInfo> outputs_desc_;
bool BuildDLDevice(Device device);
bool BuildModel(const RuntimeOption& runtime_option);
bool InitInputAndOutputTensor();
std::vector<tvm::runtime::NDArray> input_tensor_;
std::vector<tvm::runtime::NDArray> output_tensor_;
FDDataType TVMTensorTypeToFDDataType(tvm::String type);
DLDataType FDDataTypeToDLDataType(FDDataType dtype);
};
} // namespace fastdeploy


@@ -32,8 +32,10 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) {
out << "Backend::POROS";
} else if (backend == Backend::LITE) {
out << "Backend::PDLITE";
} else if(backend == Backend::HORIZONNPU){
} else if (backend == Backend::HORIZONNPU) {
out << "Backend::HORIZONNPU";
} else if (backend == Backend::TVM) {
out << "Backend::TVM";
} else {
out << "UNKNOWN-Backend";
}
@@ -88,8 +90,9 @@ std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
out << "ModelFormat::TORCHSCRIPT";
} else if (format == ModelFormat::HORIZON) {
out << "ModelFormat::HORIZON";
}
else {
} else if (format == ModelFormat::TVMFormat) {
out << "ModelFormat::TVMFormat";
} else {
out << "UNKNOWN-ModelFormat";
}
return out;
@@ -123,6 +126,9 @@ std::vector<Backend> GetAvailableBackends() {
#endif
#ifdef ENABLE_SOPHGO_BACKEND
backends.push_back(Backend::SOPHGOTPU);
#endif
#ifdef ENABLE_TVM_BACKEND
backends.push_back(Backend::TVM);
#endif
return backends;
}
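With TVM registered in GetAvailableBackends(), an application can verify at runtime that the backend was compiled in. A minimal sketch, assuming the enums and GetAvailableBackends() are exposed through fastdeploy/runtime/enum_variables.h (adjust the include if your installed headers differ):
```c++
// Sketch: check whether Backend::TVM is available in this FastDeploy build.
#include "fastdeploy/runtime/enum_variables.h"  // assumed header location
#include <algorithm>
#include <iostream>

int main() {
  auto backends = fastdeploy::GetAvailableBackends();
  bool has_tvm = std::find(backends.begin(), backends.end(),
                           fastdeploy::Backend::TVM) != backends.end();
  std::cout << "TVM backend available: " << (has_tvm ? "yes" : "no") << std::endl;
  return 0;
}
```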


@@ -39,6 +39,7 @@ enum Backend {
RKNPU2, ///< RKNPU2, support RKNN format model, Rockchip NPU only
SOPHGOTPU, ///< SOPHGOTPU, support SOPHGO format model, Sophgo TPU only
HORIZONNPU, ///< HORIZONNPU, support Horizon format model, Horizon NPU
TVM, ///< TVMBackend, support TVM format model, CPU / Nvidia GPU
};
/**
@@ -74,6 +75,7 @@ enum ModelFormat {
TORCHSCRIPT, ///< Model with TorchScript format
SOPHGO, ///< Model with SOPHGO format
HORIZON, ///< Model with HORIZON format
TVMFormat, ///< Model with TVM format
};
/// Describe all the supported backends for specified model format
@@ -85,16 +87,17 @@ static std::map<ModelFormat, std::vector<Backend>>
{ModelFormat::RKNN, {Backend::RKNPU2}},
{ModelFormat::HORIZON, {Backend::HORIZONNPU}},
{ModelFormat::TORCHSCRIPT, {Backend::POROS}},
{ModelFormat::SOPHGO, {Backend::SOPHGOTPU}}
{ModelFormat::SOPHGO, {Backend::SOPHGOTPU}},
{ModelFormat::TVMFormat, {Backend::TVM}}
};
/// Describe all the supported backends for specified device
static std::map<Device, std::vector<Backend>>
s_default_backends_by_device = {
{Device::CPU, {Backend::LITE, Backend::PDINFER, Backend::ORT,
Backend::OPENVINO, Backend::POROS}},
Backend::OPENVINO, Backend::POROS, Backend::TVM}},
{Device::GPU, {Backend::LITE, Backend::PDINFER, Backend::ORT,
Backend::TRT, Backend::POROS}},
Backend::TRT, Backend::POROS, Backend::TVM}},
{Device::RKNPU, {Backend::RKNPU2}},
{Device::SUNRISENPU, {Backend::HORIZONNPU}},
{Device::IPU, {Backend::PDINFER}},


@@ -53,6 +53,10 @@
#include "fastdeploy/runtime/backends/horizon/horizon_backend.h"
#endif
#ifdef ENABLE_TVM_BACKEND
#include "fastdeploy/runtime/backends/tvm/tvm_backend.h"
#endif
namespace fastdeploy {
bool AutoSelectBackend(RuntimeOption& option) {
@@ -159,10 +163,11 @@ bool Runtime::Init(const RuntimeOption& _option) {
CreateSophgoNPUBackend();
} else if (option.backend == Backend::POROS) {
CreatePorosBackend();
} else if (option.backend == Backend::HORIZONNPU){
} else if (option.backend == Backend::HORIZONNPU) {
CreateHorizonBackend();
}
else {
} else if (option.backend == Backend::TVM) {
CreateTVMBackend();
} else {
std::string msg = Str(GetAvailableBackends());
FDERROR << "The compiled FastDeploy only supports " << msg << ", "
<< option.backend << " is not supported now." << std::endl;
@@ -287,6 +292,19 @@ void Runtime::CreateOpenVINOBackend() {
<< "." << std::endl;
}
void Runtime::CreateTVMBackend() {
#ifdef ENABLE_TVM_BACKEND
backend_ = utils::make_unique<TVMBackend>();
FDASSERT(backend_->Init(option), "Failed to initialize TVM backend.");
#else
FDASSERT(false,
"TVMBackend is not available, please compiled with "
"ENABLE_TVM_BACKEND=ON.");
#endif
FDINFO << "Runtime initialized with Backend::TVM in " << option.device << "."
<< std::endl;
}
void Runtime::CreateOrtBackend() {
#ifdef ENABLE_ORT_BACKEND
backend_ = utils::make_unique<OrtBackend>();
@@ -342,13 +360,12 @@ void Runtime::CreateRKNPU2Backend() {
<< "." << std::endl;
}
void Runtime::CreateHorizonBackend(){
void Runtime::CreateHorizonBackend() {
#ifdef ENABLE_HORIZON_BACKEND
backend_ = utils::make_unique<HorizonBackend>();
FDASSERT(backend_->Init(option), "Failed to initialize Horizon backend.");
#else
FDASSERT(false,
"HorizonBackend is not available, please compiled with ",
FDASSERT(false, "HorizonBackend is not available, please compiled with ",
" ENABLE_HORIZON_BACKEND=ON.");
#endif
FDINFO << "Runtime initialized with Backend::HORIZONNPU in " << option.device


@@ -118,6 +118,7 @@ struct FASTDEPLOY_DECL Runtime {
void CreateHorizonBackend();
void CreateSophgoNPUBackend();
void CreatePorosBackend();
void CreateTVMBackend();
std::unique_ptr<BaseBackend> backend_;
std::vector<FDTensor> input_tensors_;
std::vector<FDTensor> output_tensors_;


@@ -70,9 +70,7 @@ void RuntimeOption::UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name,
device = Device::RKNPU;
}
void RuntimeOption::UseHorizon(){
device = Device::SUNRISENPU;
}
void RuntimeOption::UseHorizon() { device = Device::SUNRISENPU; }
void RuntimeOption::UseTimVX() {
device = Device::TIMVX;
@@ -84,8 +82,7 @@ void RuntimeOption::UseKunlunXin(int kunlunxin_id,
bool locked, bool autotune,
const std::string& autotune_file,
const std::string& precision,
bool adaptive_seqlen,
bool enable_multi_stream,
bool adaptive_seqlen, bool enable_multi_stream,
int64_t gm_default_size) {
#ifdef WITH_KUNLUNXIN
device = Device::KUNLUNXIN;
@@ -236,7 +233,7 @@ void RuntimeOption::UseLiteBackend() {
#endif
}
void RuntimeOption::UseHorizonNPUBackend(){
void RuntimeOption::UseHorizonNPUBackend() {
#ifdef ENABLE_HORIZON_BACKEND
backend = Backend::HORIZONNPU;
#else
@@ -524,4 +521,12 @@ void RuntimeOption::DisablePaddleTrtOPs(const std::vector<std::string>& ops) {
paddle_infer_option.DisableTrtOps(ops);
}
void RuntimeOption::UseTVMBackend() {
#ifdef ENABLE_TVM_BACKEND
backend = Backend::TVM;
#else
FDASSERT(false, "The FastDeploy didn't compile with TVMBackend.");
#endif
}
} // namespace fastdeploy
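With UseTVMBackend() in place, backend selection at the Runtime level follows the same pattern as the other backends. A rough sketch, assuming the usual RuntimeOption::SetModelPath() API (the model path is illustrative; as shown in tvm_backend.cc above, the backend appends ".so" and ".params" to it when loading):
```c++
// Sketch: select the TVM backend through the generic Runtime API.
#include "fastdeploy/runtime/runtime.h"  // assumed header location
#include <iostream>

int main() {
  fastdeploy::RuntimeOption option;
  option.UseCpu();
  option.UseTVMBackend();
  option.SetModelPath("tvm_save/tvm_model", "", fastdeploy::ModelFormat::TVMFormat);

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) {
    std::cerr << "Failed to initialize the runtime with the TVM backend." << std::endl;
    return -1;
  }
  return 0;
}
```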


@@ -31,6 +31,7 @@
#include "fastdeploy/runtime/backends/rknpu2/option.h"
#include "fastdeploy/runtime/backends/sophgo/option.h"
#include "fastdeploy/runtime/backends/tensorrt/option.h"
#include "fastdeploy/runtime/backends/tvm/option.h"
#include "fastdeploy/benchmark/option.h"
namespace fastdeploy {
@@ -160,6 +161,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
LiteBackendOption paddle_lite_option;
/// Option to configure RKNPU2 backend
RKNPU2BackendOption rknpu2_option;
/// Option to configure TVM backend
TVMBackendOption tvm_option;
// \brief Set the profile mode as 'true'.
//
@@ -282,6 +285,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
void UsePaddleBackend();
void UseLiteBackend();
void UseHorizonNPUBackend();
void UseTVMBackend();
};
} // namespace fastdeploy


@@ -67,7 +67,7 @@ class FASTDEPLOY_DECL SOLOv2 : public PPDetBase {
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = { Backend::PDINFER};
valid_cpu_backends = {Backend::PDINFER};
valid_gpu_backends = {Backend::PDINFER, Backend::TRT};
initialized = Initialize();
}
@@ -92,7 +92,7 @@ class FASTDEPLOY_DECL PPYOLOE : public PPDetBase {
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
Backend::LITE};
Backend::LITE, Backend::TVM};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
valid_timvx_backends = {Backend::LITE};
valid_kunlunxin_backends = {Backend::LITE};

tools/tvm/paddle2tvm.py (new file, 58 lines)

@@ -0,0 +1,58 @@
import paddle
import tvm
from tvm import relay
from tvm.contrib import graph_executor
import os
import argparse
def get_config():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path", default="./picodet_l_320_coco_lcnet/model")
parser.add_argument(
"--shape_dict",
default={"image": [1, 3, 320, 320],
"scale_factor": [1, 2]})
parser.add_argument("--tvm_save_name", default="tvm_model")
parser.add_argument("--tvm_save_path", default="./tvm_save")
args = parser.parse_args()
return args
def read_model(model_path):
return paddle.jit.load(model_path)
def paddle_to_tvm(paddle_model,
shape_dict,
tvm_save_name="tvm_model",
tvm_save_path="./tvm_save"):
if isinstance(shape_dict, str):
shape_dict = eval(shape_dict)
mod, params = relay.frontend.from_paddle(paddle_model, shape_dict)
# Test on the host CPU first, so build with the LLVM target
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
# Use TVM to build the optimized model
with tvm.transform.PassContext(opt_level=2):
base_lib = relay.build_module.build(mod, target, params=params)
if not os.path.exists(tvm_save_path):
os.mkdir(tvm_save_path)
lib_save_path = os.path.join(tvm_save_path, tvm_save_name + ".so")
base_lib.export_library(lib_save_path)
param_save_path = os.path.join(tvm_save_path,
tvm_save_name + ".params")
with open(param_save_path, 'wb') as fo:
fo.write(relay.save_param_dict(base_lib.get_params()))
module = graph_executor.GraphModule(base_lib['default'](dev))
module.load_params(relay.save_param_dict(base_lib.get_params()))
print("export success")
if __name__ == "__main__":
config = get_config()
paddle_model = read_model(config.model_path)
shape_dict = config.shape_dict
paddle_to_tvm(paddle_model, shape_dict, config.tvm_save_name,
config.tvm_save_path)
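For reference, the script can be run with its argument defaults (which point at the PicoDet model above) or with explicit arguments matching the PPYOLOE example README; both invocations below are sketches that assume the pruned models sit in the working directory:
```bash
# Default invocation: compiles ./picodet_l_320_coco_lcnet/model and writes
# tvm_model.so / tvm_model.params into ./tvm_save
python tools/tvm/paddle2tvm.py

# Explicit invocation for the pruned PPYOLOE model
python tools/tvm/paddle2tvm.py \
  --model_path=./ppyoloe_plus_crn_m_80e_coco/model \
  --shape_dict="{'image': [1, 3, 640, 640], 'scale_factor': [1, 2]}"
```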