diff --git a/CMakeLists.txt b/CMakeLists.txt index 04954acb1..314348a00 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -262,7 +262,6 @@ endif() if(ENABLE_POROS_BACKEND) set(CMAKE_CXX_STANDARD 14) - add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) add_definitions(-DENABLE_POROS_BACKEND) list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_POROS_SRCS}) include(${PROJECT_SOURCE_DIR}/cmake/poros.cmake) @@ -279,20 +278,7 @@ if(ENABLE_POROS_BACKEND) else () message(STATUS "site-packages: ${Python3_SITELIB}") endif () - # find pytorch - find_package(Torch ${PYTORCH_MINIMUM_VERSION} REQUIRED HINTS ${Python3_SITELIB}) - include_directories(${TORCH_INCLUDE_DIRS}) include_directories(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/common) - list(APPEND DEPEND_LIBS ${TORCH_LIBRARY}) - if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/torch") - file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/torch") - endif() - if(EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/torch/lib") - file(REMOVE_RECURSE "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/torch/lib") - endif() - find_package(Python COMPONENTS Interpreter Development REQUIRED) - message(STATUS "Copying ${TORCH_INSTALL_PREFIX}/lib to ${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/torch/lib ...") - execute_process(COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/copy_directory.py ${TORCH_INSTALL_PREFIX}/lib ${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/torch/lib) # find trt if(NOT WITH_GPU) message(FATAL_ERROR "While -DENABLE_POROS_BACKEND=ON, must set -DWITH_GPU=ON, but now it's OFF") diff --git a/FastDeploy.cmake.in b/FastDeploy.cmake.in old mode 100644 new mode 100755 index 44e7eb8a6..e148e5121 --- a/FastDeploy.cmake.in +++ b/FastDeploy.cmake.in @@ -121,8 +121,10 @@ endif() if(ENABLE_POROS_BACKEND) find_library(POROS_LIB poros ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/poros/lib NO_DEFAULT_PATH) - list(APPEND FASTDEPLOY_LIBS ${POROS_LIB}) - list(APPEND FASTDEPLOY_INCS ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/poros/include) + find_library(TORCH_LIB torch ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/torch/lib NO_DEFAULT_PATH) + set(TORCH_INCLUDE "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/torch/include") + list(APPEND FASTDEPLOY_LIBS ${POROS_LIB} ${TORCH_LIB}) + list(APPEND FASTDEPLOY_INCS ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/poros/include ${TORCH_INCLUDE}) endif() if(WITH_GPU) diff --git a/cmake/poros.cmake b/cmake/poros.cmake index a457f9181..894f76b37 100755 --- a/cmake/poros.cmake +++ b/cmake/poros.cmake @@ -47,11 +47,10 @@ elseif(APPLE) else() if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") message(FATAL_ERROR "Poros Backend doesn't support linux aarch64 now.") - set(POROS_FILE "poros-linux-aarch64-${POROS_VERSION}.tgz") else() - set(POROS_FILE "poros-linux-x64-${POROS_VERSION}.tgz") + message(FATAL_ERROR "Poros currently only provides precompiled packages for the GPU version.") if(WITH_GPU) - set(POROS_FILE "poros-linux-x64-gpu-${POROS_VERSION}.tgz") + set(POROS_FILE "poros_manylinux_torch1.12.1_cu116_trt8.4_gcc82-${POROS_VERSION}.tar.gz") endif() endif() endif() @@ -74,3 +73,18 @@ add_library(external_poros STATIC IMPORTED GLOBAL) set_property(TARGET external_poros PROPERTY IMPORTED_LOCATION ${POROS_COMPILE_LIB}) add_dependencies(external_poros ${POROS_PROJECT}) + +# Download libtorch.so with ABI=1 +set(TORCH_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/") +set(TORCH_FILE "libtorch-cxx11-abi-shared-with-deps-1.12.1%2Bcu116.zip") +set(TROCH_URL "${TORCH_URL_BASE}${TORCH_FILE}") +message(STATUS "Use the default Torch lib from: ${TORCH_URL}") +download_and_decompress(${TORCH_URL} ${CMAKE_CURRENT_BINARY_DIR}/${TORCH_FILE} ${THIRD_PARTY_PATH}/install) +if(EXISTS ${THIRD_PARTY_PATH}/install/torch) + file(REMOVE_RECURSE ${THIRD_PARTY_PATH}/install/torch) +endif() +file(RENAME ${THIRD_PARTY_PATH}/install/libtorch/ ${THIRD_PARTY_PATH}/install/torch) +set(TORCH_INCLUDE_DIRS ${THIRD_PARTY_PATH}/install/torch/include) +find_library(TORCH_LIBRARY torch ${THIRD_PARTY_PATH}/install/torch/lib NO_DEFAULT_PATH) +include_directories(${TORCH_INCLUDE_DIRS}) +list(APPEND DEPEND_LIBS ${TORCH_LIBRARY}) diff --git a/examples/runtime/README.md b/examples/runtime/README.md index 2f739b860..a4fb921c7 100755 --- a/examples/runtime/README.md +++ b/examples/runtime/README.md @@ -13,6 +13,7 @@ FastDeploy Runtime 推理示例如下 | python/infer_onnx_openvino.py | Python | Deploy ONNX model with OpenVINO(CPU) | | python/infer_onnx_tensorrt.py | Python | Deploy ONNX model with TensorRT(GPU) | | python/infer_onnx_onnxruntime.py | Python | Deploy ONNX model with ONNX Runtime(CPU/GPU) | +| python/infer_torchscript_poros.py | Python | Deploy TorchScript model with Poros Runtime(CPU/GPU) | ## C++ 示例 @@ -25,6 +26,7 @@ FastDeploy Runtime 推理示例如下 | cpp/infer_onnx_openvino.cc | C++ | Deploy ONNX model with OpenVINO(CPU) | | cpp/infer_onnx_tensorrt.cc | C++ | Deploy ONNX model with TensorRT(GPU) | | cpp/infer_onnx_onnxruntime.cc | C++ | Deploy ONNX model with ONNX Runtime(CPU/GPU) | +| cpp/infer_torchscript_poros.cc | C++ | Deploy TorchScript model with Poros Runtime(CPU/GPU) | ## 详细部署文档 diff --git a/examples/runtime/cpp/infer_torchscript_poros.cc b/examples/runtime/cpp/infer_torchscript_poros.cc new file mode 100644 index 000000000..d9bf4ebad --- /dev/null +++ b/examples/runtime/cpp/infer_torchscript_poros.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/runtime.h" + +namespace fd = fastdeploy; + +void build_test_data(std::vector> &prewarm_datas, bool is_dynamic) { + if (is_dynamic == false) { + std::vector inputs_data; + inputs_data.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data.size(); ++i) { + inputs_data[i] = std::rand() % 1000 / 1000.0f; + } + prewarm_datas[0][0].Resize({1, 3, 224, 224}, fd::FDDataType::FP32); + fd::FDTensor::CopyBuffer(prewarm_datas[0][0].Data(), + inputs_data.data(), + prewarm_datas[0][0].Nbytes()); + return; + } + //max + std::vector inputs_data_max; + inputs_data_max.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data_max.size(); ++i) { + inputs_data_max[i] = std::rand() % 1000 / 1000.0f; + } + prewarm_datas[0][0].Resize({1, 3, 224, 224}, fd::FDDataType::FP32); + fd::FDTensor::CopyBuffer(prewarm_datas[0][0].Data(), + inputs_data_max.data(), + prewarm_datas[0][0].Nbytes()); + //min + std::vector inputs_data_min; + inputs_data_min.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data_min.size(); ++i) { + inputs_data_min[i] = std::rand() % 1000 / 1000.0f; + } + prewarm_datas[1][0].Resize({1, 3, 224, 224}, fd::FDDataType::FP32); + fd::FDTensor::CopyBuffer(prewarm_datas[1][0].Data(), + inputs_data_min.data(), + prewarm_datas[1][0].Nbytes()); + + //opt + std::vector inputs_data_opt; + inputs_data_opt.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data_opt.size(); ++i) { + inputs_data_opt[i] = std::rand() % 1000 / 1000.0f; + } + prewarm_datas[2][0].Resize({1, 3, 224, 224}, fd::FDDataType::FP32); + fd::FDTensor::CopyBuffer(prewarm_datas[2][0].Data(), + inputs_data_opt.data(), + prewarm_datas[2][0].Nbytes()); + +} + +int main(int argc, char* argv[]) { + // prewarm_datas + bool is_dynamic = true; + std::vector> prewarm_datas; + if (is_dynamic) { + prewarm_datas.resize(3); + prewarm_datas[0].resize(1); + prewarm_datas[1].resize(1); + prewarm_datas[2].resize(1); + } else { + prewarm_datas.resize(1); + prewarm_datas[0].resize(1); + } + build_test_data(prewarm_datas, is_dynamic); + std::string model_file = "std_resnet50_script.pt"; + + // setup option + fd::RuntimeOption runtime_option; + runtime_option.SetModelPath(model_file, "", fd::ModelFormat::TORCHSCRIPT); + runtime_option.UsePorosBackend(); + runtime_option.UseGpu(0); + runtime_option.is_dynamic = true; + + // Compile runtime + std::unique_ptr runtime = + std::unique_ptr(new fd::Runtime()); + if (!runtime->Compile(prewarm_datas, runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << model_file << std::endl; + } + + std::vector input_tensors; + input_tensors.resize(1); + std::vector output_tensors; + output_tensors.resize(1); + + std::vector inputs_data; + inputs_data.resize(1 * 3 * 224 * 224); + for (size_t i = 0; i < inputs_data.size(); ++i) { + inputs_data[i] = std::rand() % 1000 / 1000.0f; + } + input_tensors[0].SetExternalData({1, 3, 224, 224}, fd::FDDataType::FP32, inputs_data.data()); + + runtime->Infer(input_tensors, &output_tensors); + + output_tensors[0].PrintInfo(); + return 0; +} \ No newline at end of file diff --git a/examples/runtime/python/infer_torchscript_poros.py b/examples/runtime/python/infer_torchscript_poros.py new file mode 100644 index 000000000..de31061f0 --- /dev/null +++ b/examples/runtime/python/infer_torchscript_poros.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastdeploy import ModelFormat + +import fastdeploy as fd +import numpy as np + + +def load_example_input_datas(): + """prewarm datas""" + data_list = [] + # max size + input_1 = np.ones((1, 3, 224, 224), dtype=np.float32) + max_inputs = [input_1] + data_list.append(tuple(max_inputs)) + + # min size + input_1 = np.ones((1, 3, 224, 224), dtype=np.float32) + min_inputs = [input_1] + data_list.append(tuple(min_inputs)) + + # opt size + input_1 = np.ones((1, 3, 224, 224), dtype=np.float32) + opt_inputs = [input_1] + data_list.append(tuple(opt_inputs)) + + return data_list + + +if __name__ == '__main__': + # prewarm_datas + prewarm_datas = load_example_input_datas() + # download model + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/std_resnet50_script.pt" + fd.download(model_url, path=".") + + option = fd.RuntimeOption() + option.use_gpu(0) + option.use_poros_backend() + option.set_model_path( + "std_resnet50_script.pt", model_format=ModelFormat.TORCHSCRIPT) + option.is_dynamic = True + # compile + runtime = fd.Runtime(option) + runtime.compile(prewarm_datas) + + # infer + input_data_0 = np.random.rand(1, 3, 224, 224).astype("float32") + result = runtime.forward(input_data_0) + print(result[0].shape) diff --git a/poros/CMakeLists.txt b/poros/CMakeLists.txt old mode 100644 new mode 100755 index 3245bbae3..0ed93f256 --- a/poros/CMakeLists.txt +++ b/poros/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.21) project(poros) set(CMAKE_CXX_STANDARD 14) -add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) option(BUILD_STATIC "build lib${PROJECT_NAME}.a static lib" OFF)