Support Poros Backend (#188)

* Add poros backend
* Add torch lib
* Add python3 lib
* Set C++14 for poros
* Fix bugs
* Fix grammar bugs
* Fix code bugs
* Fix CreatePorosValue bug
* Add AtType2String for log
* Fix trt_option
* Fix poros.cmake path
* Fix grammar bug
* Fix ambiguous reference
* Fix reference error
* Fix include files
* Remove ENABLE_TRT_BACKEND in poros
* Update CMakeLists.txt
* Fix CMakeLists.txt
* Add libtorch.so in CMakeLists.txt
* Fix CMakeLists.txt
* Fix copy bug
* Fix CMake
* Debug
* Debug utils
* Copy to CPU
* Remove log info
* Test shared memory
* Test multiple outputs
* Measure time cost
* Fix bug
* Collect timings
* Memory copy
* Remove time log
* Remove shared memory
* Fix multi-inputs bug
* Add set_input_dtypes func
* Add SetInputDtypes
* Fix bug
* Fix prewarm data order
* Debug
* Fix bug
* Add Compile func
* Add is_dynamic option
* Remove infer log
* Add CUDA 11.6 poros lib
* Fix bug
* Fix multiple outputs
* Remove logs
* Test
* Add test log
* Support CPU
* Support member variable definition
* Remove useless log
* Fix name
* Resolve conflicts
* Fix CMake
* Add GetInputInfos & GetOutputInfos
* Fix bug
* Fix runtime.py
* Add compile func
* Add np
* Deal with review comments
* Remove to_inter func
* Add property
Author: WJJ1995 (committed by GitHub)
Date: 2022-10-17 15:28:12 +08:00
Parent: c8db2dd1ef
Commit: f5c94e5471
19 changed files with 1333 additions and 12 deletions


@@ -0,0 +1,167 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <algorithm>
#include <unordered_map>
#include <set>
#include "torch/script.h"
#include "iengine.h"
#include "poros_module.h"
namespace baidu {
namespace mirana {
namespace poros {
/**
* @brief compile graph
*
* @param [in] module : original module
* @param [in] prewarm_datas : prewarm data
* @param [in] options : Inference options
* @return PorosModule
* @retval !nullptr => succeed, nullptr => failed
**/
std::unique_ptr<PorosModule> Compile(const torch::jit::Module& module,
const std::vector<std::vector<c10::IValue> >& prewarm_datas,
const PorosOptions& options);
class Compiler {
public:
typedef std::unordered_map<const torch::jit::Node*, IEngine*> engine_map_t;
typedef std::vector<std::vector<c10::IValue> > ivalue_vec_t;
Compiler() : _origin_module(NULL) {}
~Compiler();
/**
* @brief initialize the Compiler
*
* @param [in] options : poros options
* @return int
* @retval 0 => succeed <0 => failed
**/
int init(const PorosOptions& options);
/**
* @brief compile whole graph
*
* @param [in] origin_module
* @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
* @param [out] optimized_module : optimized graph
* @return int
* @retval 0 => succeed <0 => failed
**/
int compile(const torch::jit::Module& origin_module,
const ivalue_vec_t& prewarm_datas,
torch::jit::Module* optimized_module);
private:
/**
* @brief preprocess this calculation graph
*
* @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
* @param [out] graph : preprocessed graph
* @return int
* @retval 0 => succeed <0 => failed
**/
int preprocess_graph(const ivalue_vec_t& prewarm_datas, std::shared_ptr<torch::jit::Graph>& graph);
/**
* @brief segment this calculation graph
*
* @param [in/out] graph
* @return int
* @retval 0 => succeed <0 => failed
**/
int segment_graph(std::shared_ptr<torch::jit::Graph>& graph);
// Split a block into subgraphs.
// Each divided subgraph is associated with the block as a subgraph node.
int segment_block(torch::jit::Block& block, IEngine* engine, int current_depth);
// Subgraph optimization
/**
* @brief Subgraph optimization
*
* @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
* @param [in] opt_graph : graph to be optimized
* @param [out] optimized_module : optimized graph
* @return int
* @retval 0 => succeed <0 => failed
**/
int optimize_subgraph(const ivalue_vec_t& prewarm_datas,
const std::shared_ptr<torch::jit::Graph>& opt_graph,
torch::jit::Module* optimized_module);
// Subgraph optimization(block)
int optimize_subblock(torch::jit::Block* block,
torch::jit::Module* optimized_module);
/**
* @brief Compile the subgraph into a new graph based on the engine
*
* @param [in] engine : The engine used by the subgraph
* @param [in] subgraph_node : Subgraph node
* @param [out] module : transformed module
* @return int
* @retval 0 => succeed <0 => failed
**/
int transform(IEngine* engine, torch::jit::Node& subgraph_node,
torch::jit::Module& module);
/**
* @brief Select engine based on subgraph and options
*
* @param [in] n : JIT node
* @return IEngine*
* @retval !nullptr => succeed, nullptr => failed
**/
IEngine* select_engine(const torch::jit::Node* n);
/**
* @brief destroy
*
* @return void
**/
void close();
private:
int _max_segment_depth{5}; // Maximum subgraph segmentation depth
ivalue_vec_t _prewarm_datas; // Prewarm data
PorosOptions _options;
engine_map_t _engine_map; // Records the engine used for each subgraph node
const torch::jit::Module* _origin_module; // Original module
std::atomic<int> _engine_index = {0}; // Record engine index
};
/**
* @brief compile graph, internal use
*
* @param [in] module : original module
* @param [in] prewarm_datas : prewarm data
* @param [in] options : Inference options
* @return optimized_module
* @retval !nullptr => succeed, nullptr => failed
**/
std::unique_ptr<torch::jit::Module> CompileGraph(const torch::jit::Module& module,
const std::vector<std::vector<c10::IValue> >& prewarm_datas,
const PorosOptions& options);
} // namespace poros
} // namespace mirana
} // namespace baidu
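
For context, here is a minimal usage sketch of the Compile() API declared above. It is not part of this patch; the model path "model.pt", the input shape, and the option values are illustrative assumptions.

// Illustrative sketch only (not part of this patch): compiling a TorchScript
// module with Compile(). Path, shape, and option values are assumptions.
#include <vector>
#include "torch/script.h"
#include "compile.h"       // this header
#include "poros_module.h"  // PorosOptions / PorosModule

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");  // assumed path
  module.eval();

  // Prewarm data: each inner vector is one set of inputs for forward().
  std::vector<c10::IValue> prewarm_input;
  prewarm_input.push_back(torch::ones({1, 3, 224, 224}));
  std::vector<std::vector<c10::IValue>> prewarm_datas;
  prewarm_datas.push_back(prewarm_input);

  baidu::mirana::poros::PorosOptions options;
  options.device = baidu::mirana::poros::Device::CPU;
  options.is_dynamic = false;

  auto poros_module =
      baidu::mirana::poros::Compile(module, prewarm_datas, options);
  if (poros_module == nullptr) {
    return -1;  // nullptr signals compilation failure (see @retval above)
  }
  auto output = poros_module->forward(prewarm_input);
  return 0;
}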


@@ -0,0 +1,84 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
//from pytorch
#include "torch/script.h"
#include "torch/csrc/jit/ir/ir.h"
#include "ATen/core/interned_strings.h"
#include "plugin_create.h"
namespace baidu {
namespace mirana {
namespace poros {
/**
* the base engine class
* every registered engine should inherit from this IEngine
**/
struct PorosGraph {
torch::jit::Graph* graph = NULL;
torch::jit::Node* node = NULL;
};
typedef uint64_t EngineID;
class IEngine : public IPlugin, public torch::CustomClassHolder{
public:
virtual ~IEngine() {}
/**
* @brief initialize the engine; it is considered correctly initialized only if init succeeds
* @return int
* @retval 0 => success, <0 => fail
**/
virtual int init() = 0;
/**
* @brief During compilation, convert the subgraph into the graph structure of the corresponding engine and store it inside the engine, so that excute_engine() can be called at runtime
* @param [in] sub_graph : subgraph
* @return int
* @retval 0 => success, <0 => fail
**/
virtual int transform(const PorosGraph& sub_graph) = 0;
/**
* @brief Runtime execution logic for the subgraph
* @param [in] inputs : input tensors
* @return output tensors
**/
virtual std::vector<at::Tensor> excute_engine(const std::vector<at::Tensor>& inputs) = 0;
virtual void register_module_attribute(const std::string& name, torch::jit::Module& module) = 0;
// Identifier of the engine
virtual const std::string who_am_i() = 0;
// Whether the node is supported by the current engine
bool is_node_supported(const torch::jit::Node* node);
public:
std::pair<uint64_t, uint64_t> _num_io; // Number of input/output parameters
EngineID _id;
};
} // namespace poros
} // namespace mirana
} // namespace baidu
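
To illustrate the contract above, the following is a minimal sketch of an engine implementing IEngine. It is not part of this patch; DummyEngine is a do-nothing placeholder, whereas a real engine (for example a TensorRT-based one) would perform an actual graph conversion in transform() and run the compiled graph in excute_engine().

// Illustrative sketch only (not part of this patch): a do-nothing engine
// that satisfies the IEngine interface declared above.
#include "iengine.h"

class DummyEngine : public baidu::mirana::poros::IEngine {
 public:
  int init() override { return 0; }

  int transform(const baidu::mirana::poros::PorosGraph& sub_graph) override {
    // A real engine would convert sub_graph.graph into its own representation
    // here and keep it for excute_engine().
    _num_io = {sub_graph.graph->inputs().size(),
               sub_graph.graph->outputs().size()};
    return 0;
  }

  std::vector<at::Tensor> excute_engine(
      const std::vector<at::Tensor>& inputs) override {
    // Identity "execution": pass the inputs straight through.
    return inputs;
  }

  void register_module_attribute(const std::string& name,
                                 torch::jit::Module& module) override {
    // A real engine would attach its serialized state to `module` here.
    (void)name;
    (void)module;
  }

  const std::string who_am_i() override { return "DummyEngine"; }
};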


@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include <string>
namespace baidu {
namespace mirana {
namespace poros {
class IPlugin {
public:
virtual ~IPlugin() {}
virtual const std::string who_am_i() = 0;
};
typedef IPlugin* (*plugin_creator_t)();
typedef std::unordered_map<std::string, plugin_creator_t> plugin_creator_map_t;
IPlugin* create_plugin(const std::string& plugin_name);
IPlugin* create_plugin(const std::string& plugin_name, const plugin_creator_map_t& plugin_creator_map);
void create_all_plugins(const plugin_creator_map_t& plugin_creator_map,
std::unordered_map<std::string, IPlugin*>& plugin_m);
//void create_all_plugins(std::unordered_map<std::string, IPlugin*>& plugin_m);
template <typename PluginType>
IPlugin* default_plugin_creator() {
return new (std::nothrow)PluginType;
}
void register_plugin_creator(const std::string& plugin_name, plugin_creator_t creator);
void register_plugin_creator(const std::string& plugin_name,
plugin_creator_t creator, plugin_creator_map_t& plugin_creator_map);
template <typename PluginType>
void register_plugin_class(const std::string& plugin_name) {
return register_plugin_creator(plugin_name, default_plugin_creator<PluginType>);
}
// This version is recommended
template <typename PluginType>
void register_plugin_class(const std::string& plugin_name, plugin_creator_map_t& plugin_creator_map) {
return register_plugin_creator(plugin_name, default_plugin_creator<PluginType>, plugin_creator_map);
}
}//poros
}//mirana
}//baidu
/* vim: set ts=4 sw=4 sts=4 tw=100 */
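
The registration helpers above can be used roughly as follows. This is a sketch for illustration only; the plugin class and the name "MyPlugin" are assumptions, not part of this patch.

// Illustrative sketch only (not part of this patch): registering a plugin
// class and creating an instance through the helpers declared above.
#include "plugin_create.h"

class MyPlugin : public baidu::mirana::poros::IPlugin {
 public:
  const std::string who_am_i() override { return "MyPlugin"; }
};

void plugin_example() {
  baidu::mirana::poros::plugin_creator_map_t creator_map;
  // Recommended form: register into an explicit creator map.
  baidu::mirana::poros::register_plugin_class<MyPlugin>("MyPlugin", creator_map);

  baidu::mirana::poros::IPlugin* plugin =
      baidu::mirana::poros::create_plugin("MyPlugin", creator_map);
  // plugin->who_am_i() returns "MyPlugin"
  delete plugin;
}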


@@ -0,0 +1,67 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "torch/script.h"
#include "torch/csrc/jit/jit_log.h"
// #include "ATen/Context.h"
namespace baidu {
namespace mirana {
namespace poros {
enum Device : int8_t {
GPU = 0,
CPU,
XPU,
UNKNOW
};
struct PorosOptions {
Device device = GPU;
bool debug = false;
bool use_fp16 = false;
bool is_dynamic = false;
bool long_to_int = true;
uint64_t max_workspace_size = 1ULL << 30;
int32_t device_id = -1;
int32_t unconst_ops_thres = -1;
bool use_nvidia_tf32 = false;
};
class PorosModule : public torch::jit::Module {
public:
PorosModule(torch::jit::Module module) : torch::jit::Module(module) {
}
~PorosModule() = default;
void to_device(Device device){
_options.device = device;
}
//c10::IValue forward(std::vector<c10::IValue> inputs);
//void save(const std::string& filename);
public:
PorosOptions _options;
};
// Load a module previously saved via PorosModule::save
std::unique_ptr<PorosModule> Load(const std::string& filename, const PorosOptions& options);
} // namespace poros
} // namespace mirana
} // namespace baidu
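
As a complement to Compile(), a compiled module can be reloaded with Load(). The sketch below is illustrative only; the file name "model.poros" and the option values are assumptions, not part of this patch.

// Illustrative sketch only (not part of this patch): reloading a previously
// saved Poros module with explicit options.
#include <memory>
#include "poros_module.h"

std::unique_ptr<baidu::mirana::poros::PorosModule> load_example() {
  baidu::mirana::poros::PorosOptions opts;
  opts.device = baidu::mirana::poros::Device::GPU;
  opts.use_fp16 = true;
  opts.max_workspace_size = 1ULL << 30;
  return baidu::mirana::poros::Load("model.poros", opts);  // assumed file name
}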


@@ -0,0 +1,240 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/backends/poros/poros_backend.h"
#include <sys/time.h>
namespace fastdeploy {
TensorInfo PorosBackend::GetInputInfo(int index) {
// eager mode can't obtain input information before inference
TensorInfo info_input;
return info_input;
}
TensorInfo PorosBackend::GetOutputInfo(int index) {
// eager mode can't obtain output information before inference
TensorInfo info_output;
return info_output;
}
std::vector<TensorInfo> PorosBackend::GetInputInfos() {
// eager mode can't obtain inputs information before inference
std::vector<TensorInfo> info_inputs;
return info_inputs;
}
std::vector<TensorInfo> PorosBackend::GetOutputInfos() {
// eager mode can't obtain outputs information before inference
std::vector<TensorInfo> info_outputs;
return info_outputs;
}
void PorosBackend::BuildOption(const PorosBackendOption& option) {
_options.device = option.use_gpu ? baidu::mirana::poros::Device::GPU
: baidu::mirana::poros::Device::CPU;
_options.long_to_int = option.long_to_int;
_options.use_nvidia_tf32 = option.use_nvidia_tf32;
_options.device_id = option.gpu_id;
_options.unconst_ops_thres = option.unconst_ops_thres;
_options.is_dynamic = option.is_dynamic;
_options.max_workspace_size = option.max_workspace_size;
_options.use_fp16 = option.enable_fp16;
return;
}
bool PorosBackend::Compile(const std::string& model_file,
std::vector<std::vector<FDTensor>>& prewarm_tensors,
const PorosBackendOption& option) {
if (initialized_) {
FDERROR << "PorosBackend is already initlized, cannot initialize again."
<< std::endl;
return false;
}
BuildOption(option);
torch::jit::Module mod;
mod = torch::jit::load(model_file);
mod.eval();
if (option.use_gpu) {
mod.to(at::kCUDA);
} else {
mod.to(at::kCPU);
}
// get inputs_nums and outputs_nums
auto graph = mod.get_method("forward").graph();
auto inputs = graph->inputs();
// remove self node
_numinputs = inputs.size() - 1;
// FDTensor to at::Tensor
std::vector<std::vector<c10::IValue>> prewarm_datas;
bool is_backend_cuda = option.use_gpu;
for (size_t i = 0; i < prewarm_tensors.size(); ++i) {
std::vector<c10::IValue> prewarm_data;
for (size_t j = 0; j < prewarm_tensors[i].size(); ++j) {
auto tensor = CreatePorosValue(prewarm_tensors[i][j], is_backend_cuda);
prewarm_data.push_back(tensor);
}
prewarm_datas.push_back(prewarm_data);
}
// get outputs nums
auto temp_result = mod.forward(prewarm_datas[0]);
size_t outputs_nums = 0;
if (temp_result.isTensor()) {
outputs_nums += 1;
} else if (temp_result.isTuple()) {
auto temp_result_tuple = temp_result.toTuple();
for (size_t i = 0; i < temp_result_tuple->elements().size(); ++i) {
auto poros_tensor = temp_result_tuple->elements()[i];
if (poros_tensor.isTensor()) {
outputs_nums += 1;
} else if (poros_tensor.isList()) {
auto poros_tensor_list = poros_tensor.toList();
outputs_nums += poros_tensor_list.size();
} else if (poros_tensor.isTuple()) {
auto poros_tensor_tuple = poros_tensor.toTuple();
outputs_nums += poros_tensor_tuple->elements().size();
} else {
continue;
}
}
}
_numoutputs = outputs_nums;
_poros_module = baidu::mirana::poros::Compile(mod, prewarm_datas, _options);
if (_poros_module == nullptr) {
FDERROR << "PorosBackend initlize Failed, try initialize again."
<< std::endl;
return false;
}
initialized_ = true;
return true;
}
bool PorosBackend::InitFromTorchScript(const std::string& model_file,
const PorosBackendOption& option) {
if (initialized_) {
FDERROR << "PorosBackend is already initlized, cannot initialize again."
<< std::endl;
return false;
}
if (option.poros_file != "") {
std::ifstream fin(option.poros_file, std::ios::binary | std::ios::in);
if (fin) {
FDINFO << "Detect compiled Poros file in " << option.poros_file
<< ", will load it directly." << std::endl;
fin.close();
return InitFromPoros(option.poros_file, option);
}
}
BuildOption(option);
torch::jit::Module mod;
mod = torch::jit::load(model_file);
mod.eval();
if (option.use_gpu) {
mod.to(at::kCUDA);
} else {
mod.to(at::kCPU);
}
// get inputs_nums and outputs_nums
auto graph = mod.get_method("forward").graph();
auto inputs = graph->inputs();
// remove self node
_numinputs = inputs.size() - 1;
auto outputs = graph->outputs();
_numoutputs = outputs.size();
_poros_module = baidu::mirana::poros::Compile(mod, _prewarm_datas, _options);
if (_poros_module == nullptr) {
FDERROR << "PorosBackend initlize Failed, try initialize again."
<< std::endl;
return false;
}
initialized_ = true;
return true;
}
bool PorosBackend::InitFromPoros(const std::string& model_file,
const PorosBackendOption& option) {
if (initialized_) {
FDERROR << "PorosBackend is already initlized, cannot initialize again."
<< std::endl;
return false;
}
BuildOption(option);
_poros_module = baidu::mirana::poros::Load(model_file, _options);
if (_poros_module == nullptr) {
FDERROR << "PorosBackend initlize Failed, try initialize again."
<< std::endl;
return false;
}
// get inputs_nums and outputs_nums
auto graph = _poros_module->get_method("forward").graph();
auto inputs = graph->inputs();
// remove self node
_numinputs = inputs.size() - 1;
auto outputs = graph->outputs();
_numoutputs = outputs.size();
initialized_ = true;
return true;
}
bool PorosBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs) {
// Convert FD Tensor to PyTorch Tensor
std::vector<torch::jit::IValue> poros_inputs;
bool is_backend_cuda =
(_options.device == baidu::mirana::poros::Device::GPU);
for (size_t i = 0; i < inputs.size(); ++i) {
poros_inputs.push_back(CreatePorosValue(inputs[i], is_backend_cuda));
}
// Infer
auto poros_outputs = _poros_module->forward(poros_inputs);
// Convert PyTorch Tensor to FD Tensor
if (poros_outputs.isTensor()) {
CopyTensorToCpu(poros_outputs.toTensor(), &((*outputs)[0]),
is_backend_cuda);
} else if (poros_outputs.isTuple()) {
// deal with multi outputs
auto poros_outputs_tuple = poros_outputs.toTuple();
size_t index = 0;
for (size_t i = 0; i < poros_outputs_tuple->elements().size(); ++i) {
auto poros_tensor = poros_outputs_tuple->elements()[i];
if (poros_tensor.isTensor()) {
CopyTensorToCpu(poros_tensor.toTensor(), &((*outputs)[index]),
is_backend_cuda);
index += 1;
} else if (poros_tensor.isList()) {
auto poros_tensor_list = poros_tensor.toList();
for (const auto list_idx : c10::irange(0, poros_tensor_list.size())) {
const auto& elt = poros_tensor_list.get(list_idx);
CopyTensorToCpu(elt.toTensor(), &((*outputs)[index]),
is_backend_cuda);
index += 1;
}
} else if (poros_tensor.isTuple()) {
auto poros_tensor_tuple = poros_tensor.toTuple();
for (size_t j = 0; j < poros_tensor_tuple->elements().size(); ++j) {
CopyTensorToCpu(poros_tensor_tuple->elements()[j].toTensor(),
&((*outputs)[index]), is_backend_cuda);
index += 1;
}
} else {
continue;
}
}
} else {
FDERROR << "Convert to FDTensor Failed!!!!!" << std::endl;
}
return true;
}
} // namespace fastdeploy


@@ -0,0 +1,107 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/backends/poros/common/compile.h"
#include "fastdeploy/backends/poros/common/poros_module.h"
namespace fastdeploy {
struct PorosBackendOption {
#ifdef WITH_GPU
bool use_gpu = true;
#else
bool use_gpu = false;
#endif
int gpu_id = 0;
bool long_to_int = true;
// TF32 mode on NVIDIA Ampere GPUs (e.g. A10) trades calculation precision for
// speed; it can improve performance but may introduce small numerical diffs
bool use_nvidia_tf32 = false;
// Threshold for the number of non-const ops
int32_t unconst_ops_thres = -1;
std::string poros_file = "";
std::vector<FDDataType> prewarm_datatypes = {FDDataType::FP32};
// TRT options
bool enable_fp16 = false;
bool enable_int8 = false;
bool is_dynamic = false;
size_t max_batch_size = 32;
size_t max_workspace_size = 1 << 30;
};
// Convert data type from fastdeploy to poros
at::ScalarType GetPorosDtype(const FDDataType& fd_dtype);
// Convert data type from poros to fastdeploy
FDDataType GetFdDtype(const at::ScalarType& dtype);
// at::ScalarType to std::string for FDERROR
std::string AtType2String(const at::ScalarType& dtype);
// Create an at::Tensor from an FDTensor
// is_backend_cuda specifies whether Poros runs on the GPU device;
// when is_backend_cuda = true, the tensor is created on Device::GPU
at::Tensor CreatePorosValue(FDTensor& tensor, bool is_backend_cuda = false);
// Copy memory data from at::Tensor to fastdeploy::FDTensor
void CopyTensorToCpu(const at::Tensor& tensor, FDTensor* fd_tensor,
bool is_backend_cuda = false);
class PorosBackend : public BaseBackend {
public:
PorosBackend() {}
virtual ~PorosBackend() = default;
void BuildOption(const PorosBackendOption& option);
bool InitFromTorchScript(
const std::string& model_file,
const PorosBackendOption& option = PorosBackendOption());
bool InitFromPoros(const std::string& model_file,
const PorosBackendOption& option = PorosBackendOption());
bool Compile(const std::string& model_file,
std::vector<std::vector<FDTensor>>& prewarm_tensors,
const PorosBackendOption& option = PorosBackendOption());
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs);
int NumInputs() const { return _numinputs; }
int NumOutputs() const { return _numoutputs; }
TensorInfo GetInputInfo(int index) override;
TensorInfo GetOutputInfo(int index) override;
std::vector<TensorInfo> GetInputInfos() override;
std::vector<TensorInfo> GetOutputInfos() override;
private:
baidu::mirana::poros::PorosOptions _options;
std::unique_ptr<baidu::mirana::poros::PorosModule> _poros_module;
std::vector<std::vector<c10::IValue>> _prewarm_datas;
int _numinputs = 1;
int _numoutputs = 1;
};
} // namespace fastdeploy
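
A rough usage sketch of the backend declared above follows. It is not part of this patch; the model path "model.pt", the shapes, and the data types are illustrative assumptions (prewarm tensor contents are left uninitialized, which is enough for a shape-only sketch).

// Illustrative sketch only (not part of this patch): compiling a TorchScript
// model with PorosBackend and running one inference.
#include <vector>
#include "fastdeploy/backends/poros/poros_backend.h"

bool poros_backend_example() {
  fastdeploy::PorosBackendOption option;
  option.use_gpu = false;  // CPU for the sketch
  option.is_dynamic = false;

  // One group of prewarm inputs, matching the model's forward() signature.
  std::vector<std::vector<fastdeploy::FDTensor>> prewarm_tensors(1);
  prewarm_tensors[0].resize(1);
  prewarm_tensors[0][0].Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32);

  fastdeploy::PorosBackend backend;
  if (!backend.Compile("model.pt", prewarm_tensors, option)) {  // assumed path
    return false;
  }

  std::vector<fastdeploy::FDTensor> inputs(1);
  inputs[0].Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32);
  std::vector<fastdeploy::FDTensor> outputs(backend.NumOutputs());
  return backend.Infer(inputs, &outputs);
}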


@@ -0,0 +1,186 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/backends/poros/poros_backend.h"
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif
namespace fastdeploy {
std::string AtType2String(const at::ScalarType& dtype) {
std::string out;
switch (dtype) {
case at::kByte:
out = "at::kByte";
break;
case at::kChar:
out = "at::kChar";
break;
case at::kShort:
out = "at::kShort";
break;
case at::kInt:
out = "at::kInt";
break;
case at::kLong:
out = "at::kLong";
break;
case at::kHalf:
out = "at::kHalf";
break;
case at::kFloat:
out = "at::kFloat";
break;
case at::kDouble:
out = "at::kDouble";
break;
default:
out = "at::UNKNOWN";
}
return out;
}
at::ScalarType GetPorosDtype(const FDDataType& fd_dtype) {
if (fd_dtype == FDDataType::FP32) {
return at::kFloat;
} else if (fd_dtype == FDDataType::FP64) {
return at::kDouble;
} else if (fd_dtype == FDDataType::INT32) {
return at::kInt;
} else if (fd_dtype == FDDataType::INT64) {
return at::kLong;
}
FDERROR << "Unrecognized fastdeply data type:" << Str(fd_dtype) << "."
<< std::endl;
return at::kFloat;
}
FDDataType GetFdDtype(const at::ScalarType& poros_dtype) {
if (poros_dtype == at::kFloat) {
return FDDataType::FP32;
} else if (poros_dtype == at::kDouble) {
return FDDataType::FP64;
} else if (poros_dtype == at::kInt) {
return FDDataType::INT32;
} else if (poros_dtype == at::kLong) {
return FDDataType::INT64;
}
FDERROR << "Unrecognized poros data type:" << AtType2String(poros_dtype)
<< "." << std::endl;
return FDDataType::FP32;
}
at::Tensor CreatePorosValue(FDTensor& tensor, bool is_backend_cuda) {
FDASSERT(tensor.device == Device::GPU || tensor.device == Device::CPU,
"Only support tensor which device is CPU or GPU for PorosBackend.");
auto data_type = GetPorosDtype(tensor.dtype);
size_t numel = tensor.Numel();
at::Tensor poros_value;
if (is_backend_cuda) {
poros_value = std::move(
at::empty(tensor.shape, {at::kCUDA}).to(data_type).contiguous());
} else {
poros_value = std::move(
at::empty(tensor.shape, {at::kCPU}).to(data_type).contiguous());
}
if (data_type == at::kFloat) {
if (is_backend_cuda) {
cudaMemcpy(poros_value.data_ptr(), static_cast<void*>(tensor.Data()),
numel * sizeof(float), cudaMemcpyHostToDevice);
} else {
memcpy(poros_value.data_ptr(), static_cast<void*>(tensor.Data()),
numel * sizeof(float));
}
} else if (data_type == at::kInt) {
if (is_backend_cuda) {
cudaMemcpy(poros_value.data_ptr(), static_cast<void*>(tensor.Data()),
numel * sizeof(int32_t), cudaMemcpyHostToDevice);
} else {
memcpy(poros_value.data_ptr(), static_cast<void*>(tensor.Data()),
numel * sizeof(int32_t));
}
} else if (data_type == at::kLong) {
if (is_backend_cuda) {
cudaMemcpy(poros_value.data_ptr(), static_cast<void*>(tensor.Data()),
numel * sizeof(int64_t), cudaMemcpyHostToDevice);
} else {
memcpy(poros_value.data_ptr(), static_cast<void*>(tensor.Data()),
numel * sizeof(int64_t));
}
} else if (data_type == at::kDouble) {
if (is_backend_cuda) {
cudaMemcpy(poros_value.data_ptr(), static_cast<void*>(tensor.Data()),
numel * sizeof(double), cudaMemcpyHostToDevice);
} else {
memcpy(poros_value.data_ptr(), static_cast<void*>(tensor.Data()),
numel * sizeof(double));
}
} else {
FDASSERT(false,
"Unrecognized data type while calling "
"PorosBackend::CreatePorosValue().");
}
return poros_value;
}
void CopyTensorToCpu(const at::Tensor& tensor, FDTensor* fd_tensor,
bool is_backend_cuda) {
const auto data_type = tensor.scalar_type();
std::vector<int64_t> shape;
auto sizes = tensor.sizes();
for (size_t i = 0; i < sizes.size(); i++) {
shape.push_back(sizes[i]);
}
auto fd_dtype = GetFdDtype(data_type);
fd_tensor->Resize(shape, fd_dtype);
size_t numel = tensor.numel();
// at::Tensor -> FDTensor
if (data_type == at::kFloat) {
if (is_backend_cuda) {
cudaMemcpy(fd_tensor->Data(), tensor.data_ptr(), numel * sizeof(float),
cudaMemcpyDeviceToHost);
} else {
memcpy(fd_tensor->Data(), tensor.data_ptr(), numel * sizeof(float));
}
return;
} else if (data_type == at::kInt) {
if (is_backend_cuda) {
cudaMemcpy(fd_tensor->Data(), tensor.data_ptr(), numel * sizeof(int32_t),
cudaMemcpyDeviceToHost);
} else {
memcpy(fd_tensor->Data(), tensor.data_ptr(), numel * sizeof(int32_t));
}
return;
} else if (data_type == at::kLong) {
if (is_backend_cuda) {
cudaMemcpy(fd_tensor->Data(), tensor.data_ptr(), numel * sizeof(int64_t),
cudaMemcpyDeviceToHost);
} else {
memcpy(fd_tensor->Data(), tensor.data_ptr(), numel * sizeof(int64_t));
}
return;
} else if (data_type == at::kDouble) {
if (is_backend_cuda) {
cudaMemcpy(fd_tensor->Data(), tensor.data_ptr(), numel * sizeof(double),
cudaMemcpyDeviceToHost);
} else {
memcpy(fd_tensor->Data(), tensor.data_ptr(), numel * sizeof(double));
}
return;
}
}
} // namespace fastdeploy
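
The two helpers above convert between FDTensor and at::Tensor. A small CPU round-trip sketch follows; it is illustrative only and assumes the FDTensor Resize/MutableData usage seen elsewhere in this patch.

// Illustrative sketch only (not part of this patch): a CPU round trip through
// CreatePorosValue() and CopyTensorToCpu().
#include "fastdeploy/backends/poros/poros_backend.h"

void roundtrip_example() {
  fastdeploy::FDTensor src;
  src.Resize({2, 3}, fastdeploy::FDDataType::FP32);
  float* data = static_cast<float*>(src.MutableData());
  for (int i = 0; i < 6; ++i) {
    data[i] = static_cast<float>(i);
  }

  // FDTensor -> at::Tensor on the CPU backend.
  at::Tensor poros_value =
      fastdeploy::CreatePorosValue(src, /*is_backend_cuda=*/false);

  // at::Tensor -> FDTensor.
  fastdeploy::FDTensor dst;
  fastdeploy::CopyTensorToCpu(poros_value, &dst, /*is_backend_cuda=*/false);
  // dst now holds the same 2x3 float data as src.
}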

fastdeploy/core/config.h.in (Normal file → Executable file, 4 changes)

@@ -29,6 +29,10 @@
#cmakedefine ENABLE_PADDLE_BACKEND
#endif
#ifndef ENABLE_POROS_BACKEND
#cmakedefine ENABLE_POROS_BACKEND
#endif
#ifndef ENABLE_OPENVINO_BACKEND
#cmakedefine ENABLE_OPENVINO_BACKEND
#endif


@@ -24,6 +24,7 @@ void BindRuntime(pybind11::module& m) {
.def("use_cpu", &RuntimeOption::UseCpu)
.def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
.def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
.def("use_poros_backend", &RuntimeOption::UsePorosBackend)
.def("use_ort_backend", &RuntimeOption::UseOrtBackend)
.def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel)
.def("use_trt_backend", &RuntimeOption::UseTrtBackend)
@@ -62,7 +63,12 @@ void BindRuntime(pybind11::module& m) {
.def_readwrite("trt_enable_int8", &RuntimeOption::trt_enable_int8)
.def_readwrite("trt_max_batch_size", &RuntimeOption::trt_max_batch_size)
.def_readwrite("trt_max_workspace_size",
&RuntimeOption::trt_max_workspace_size);
&RuntimeOption::trt_max_workspace_size)
.def_readwrite("is_dynamic", &RuntimeOption::is_dynamic)
.def_readwrite("long_to_int", &RuntimeOption::long_to_int)
.def_readwrite("use_nvidia_tf32", &RuntimeOption::use_nvidia_tf32)
.def_readwrite("unconst_ops_thres", &RuntimeOption::unconst_ops_thres)
.def_readwrite("poros_file", &RuntimeOption::poros_file);
pybind11::class_<TensorInfo>(m, "TensorInfo")
.def_readwrite("name", &TensorInfo::name)
@@ -72,6 +78,30 @@ void BindRuntime(pybind11::module& m) {
pybind11::class_<Runtime>(m, "Runtime")
.def(pybind11::init())
.def("init", &Runtime::Init)
.def("compile",
[](Runtime& self,
std::vector<std::vector<pybind11::array>>& warm_datas,
const RuntimeOption& _option) {
size_t rows = warm_datas.size();
size_t columns = warm_datas[0].size();
std::vector<std::vector<FDTensor>> warm_tensors(
rows, std::vector<FDTensor>(columns));
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < columns; ++j) {
auto dtype =
NumpyDataTypeToFDDataType(warm_datas[i][j].dtype());
std::vector<int64_t> data_shape;
data_shape.insert(
data_shape.begin(), warm_datas[i][j].shape(),
warm_datas[i][j].shape() + warm_datas[i][j].ndim());
warm_tensors[i][j].Resize(data_shape, dtype);
memcpy(warm_tensors[i][j].MutableData(),
warm_datas[i][j].mutable_data(),
warm_datas[i][j].nbytes());
}
}
return self.Compile(warm_tensors, _option);
})
.def("infer",
[](Runtime& self, std::vector<FDTensor>& inputs) {
std::vector<FDTensor> outputs(self.NumOutputs());
@@ -121,11 +151,13 @@ void BindRuntime(pybind11::module& m) {
.value("UNKOWN", Backend::UNKNOWN)
.value("ORT", Backend::ORT)
.value("TRT", Backend::TRT)
.value("POROS", Backend::POROS)
.value("PDINFER", Backend::PDINFER)
.value("LITE", Backend::LITE);
pybind11::enum_<ModelFormat>(m, "ModelFormat", pybind11::arithmetic(),
"ModelFormat for inference.")
.value("PADDLE", ModelFormat::PADDLE)
.value("TORCHSCRIPT", ModelFormat::TORCHSCRIPT)
.value("ONNX", ModelFormat::ONNX);
pybind11::enum_<Device>(m, "Device", pybind11::arithmetic(),
"Device for inference.")


@@ -29,6 +29,10 @@
#include "fastdeploy/backends/paddle/paddle_backend.h"
#endif
#ifdef ENABLE_POROS_BACKEND
#include "fastdeploy/backends/poros/poros_backend.h"
#endif
#ifdef ENABLE_OPENVINO_BACKEND
#include "fastdeploy/backends/openvino/ov_backend.h"
#endif
@@ -50,6 +54,9 @@ std::vector<Backend> GetAvailableBackends() {
#ifdef ENABLE_PADDLE_BACKEND
backends.push_back(Backend::PDINFER);
#endif
#ifdef ENABLE_POROS_BACKEND
backends.push_back(Backend::POROS);
#endif
#ifdef ENABLE_OPENVINO_BACKEND
backends.push_back(Backend::OPENVINO);
#endif
@@ -76,6 +83,8 @@ std::string Str(const Backend& b) {
return "Backend::TRT";
} else if (b == Backend::PDINFER) {
return "Backend::PDINFER";
} else if (b == Backend::POROS) {
return "Backend::POROS";
} else if (b == Backend::OPENVINO) {
return "Backend::OPENVINO";
} else if (b == Backend::LITE) {
@@ -89,6 +98,8 @@ std::string Str(const ModelFormat& f) {
return "ModelFormat::PADDLE";
} else if (f == ModelFormat::ONNX) {
return "ModelFormat::ONNX";
} else if (f == ModelFormat::TORCHSCRIPT) {
return "ModelFormat::TORCHSCRIPT";
}
return "UNKNOWN-ModelFormat";
}
@@ -102,6 +113,8 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) {
out << "Backend::PDINFER";
} else if (backend == Backend::OPENVINO) {
out << "Backend::OPENVINO";
} else if (backend == Backend::POROS) {
out << "Backend::POROS";
} else if (backend == Backend::LITE) {
out << "Backend::LITE";
}
@@ -114,6 +127,8 @@ std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
out << "ModelFormat::PADDLE";
} else if (format == ModelFormat::ONNX) {
out << "ModelFormat::ONNX";
} else if (format == ModelFormat::TORCHSCRIPT) {
out << "ModelFormat::TORCHSCRIPT";
}
out << "UNKNOWN-ModelFormat";
return out;
@@ -137,9 +152,17 @@ bool CheckModelFormat(const std::string& model_file,
<< model_file << std::endl;
return false;
}
} else if (model_format == ModelFormat::TORCHSCRIPT) {
if (model_file.size() < 3 ||
model_file.substr(model_file.size() - 3, 3) != ".pt") {
FDERROR << "With model format of ModelFormat::TORCHSCRIPT, the model file "
"should ends with `.pt`, but now it's "
<< model_file << std::endl;
return false;
}
} else {
FDERROR << "Only support model format with frontend ModelFormat::PADDLE / "
"ModelFormat::ONNX."
"ModelFormat::ONNX / ModelFormat::TORCHSCRIPT."
<< std::endl;
return false;
}
@@ -155,6 +178,10 @@ ModelFormat GuessModelFormat(const std::string& model_file) {
model_file.substr(model_file.size() - 5, 5) == ".onnx") {
FDINFO << "Model Format: ONNX." << std::endl;
return ModelFormat::ONNX;
} else if (model_file.size() > 3 &&
model_file.substr(model_file.size() - 3, 3) == ".pt") {
FDINFO << "Model Format: Torchscript." << std::endl;
return ModelFormat::TORCHSCRIPT;
}
FDERROR << "Cannot guess which model format you are using, please set "
@@ -173,10 +200,13 @@ void RuntimeOption::SetModelPath(const std::string& model_path,
} else if (format == ModelFormat::ONNX) {
model_file = model_path;
model_format = ModelFormat::ONNX;
} else if (format == ModelFormat::TORCHSCRIPT) {
model_file = model_path;
model_format = ModelFormat::TORCHSCRIPT;
} else {
FDASSERT(
false,
"The model format only can be ModelFormat::PADDLE/ModelFormat::ONNX.");
"The model format only can be ModelFormat::PADDLE/ModelFormat::ONNX/ModelFormat::TORCHSCRIPT.");
}
}
@@ -223,6 +253,15 @@ void RuntimeOption::UseOrtBackend() {
#endif
}
// use poros backend
void RuntimeOption::UsePorosBackend() {
#ifdef ENABLE_POROS_BACKEND
backend = Backend::POROS;
#else
FDASSERT(false, "The FastDeploy didn't compile with PorosBackend.");
#endif
}
void RuntimeOption::UseTrtBackend() {
#ifdef ENABLE_TRT_BACKEND
backend = Backend::TRT;
@@ -324,6 +363,36 @@ void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
trt_serialize_file = cache_file_path;
}
bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
const RuntimeOption& _option) {
#ifdef ENABLE_POROS_BACKEND
option = _option;
auto poros_option = PorosBackendOption();
poros_option.use_gpu = (option.device == Device::GPU) ? true : false;
poros_option.gpu_id = option.device_id;
poros_option.long_to_int = option.long_to_int;
poros_option.use_nvidia_tf32 = option.use_nvidia_tf32;
poros_option.unconst_ops_thres = option.unconst_ops_thres;
poros_option.poros_file = option.poros_file;
poros_option.is_dynamic = option.is_dynamic;
poros_option.enable_fp16 = option.trt_enable_fp16;
poros_option.max_batch_size = option.trt_max_batch_size;
poros_option.max_workspace_size = option.trt_max_workspace_size;
FDASSERT(option.model_format == ModelFormat::TORCHSCRIPT,
"PorosBackend only support model format of ModelFormat::TORCHSCRIPT.");
backend_ = utils::make_unique<PorosBackend>();
auto casted_backend = dynamic_cast<PorosBackend*>(backend_.get());
FDASSERT(
casted_backend->Compile(option.model_file, prewarm_tensors, poros_option),
"Load model from Torchscript failed while initliazing PorosBackend.");
#else
FDASSERT(false,
"PorosBackend is not available, please compiled with "
"ENABLE_POROS_BACKEND=ON.");
#endif
return true;
}
bool Runtime::Init(const RuntimeOption& _option) {
option = _option;
if (option.model_format == ModelFormat::AUTOREC) {
@@ -334,6 +403,8 @@ bool Runtime::Init(const RuntimeOption& _option) {
option.backend = Backend::ORT;
} else if (IsBackendAvailable(Backend::PDINFER)) {
option.backend = Backend::PDINFER;
} else if (IsBackendAvailable(Backend::POROS)) {
option.backend = Backend::POROS;
} else if (IsBackendAvailable(Backend::OPENVINO)) {
option.backend = Backend::OPENVINO;
} else {
@@ -365,6 +436,15 @@ bool Runtime::Init(const RuntimeOption& _option) {
CreatePaddleBackend();
FDINFO << "Runtime initialized with Backend::PDINFER in "
<< Str(option.device) << "." << std::endl;
} else if (option.backend == Backend::POROS) {
FDASSERT(option.device == Device::CPU || option.device == Device::GPU,
"Backend::POROS only supports Device::CPU/Device::GPU.");
FDASSERT(
option.model_format == ModelFormat::TORCHSCRIPT,
"Backend::POROS only supports model format of ModelFormat::TORCHSCRIPT.");
FDINFO << "Runtime initialized with Backend::POROS in "
<< Str(option.device) << "." << std::endl;
return true;
} else if (option.backend == Backend::OPENVINO) {
FDASSERT(option.device == Device::CPU,
"Backend::OPENVINO only supports Device::CPU");
@@ -379,7 +459,8 @@ bool Runtime::Init(const RuntimeOption& _option) {
<< "." << std::endl;
} else {
FDERROR << "Runtime only support "
"Backend::ORT/Backend::TRT/Backend::PDINFER as backend now."
"Backend::ORT/Backend::TRT/Backend::PDINFER/Backend::POROS as "
"backend now."
<< std::endl;
return false;
}
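
Putting the pieces together, the new Runtime::Compile path added above can be exercised roughly as follows. This sketch is not part of the patch; the model path, the assumed SetModelPath signature (model_file, params_file, format), and the shapes are illustrative assumptions.

// Illustrative sketch only (not part of this patch): the new Poros path
// through the Runtime API.
#include <vector>
#include "fastdeploy/runtime.h"  // assumed umbrella header

bool runtime_poros_example() {
  fastdeploy::RuntimeOption option;
  // Assumed signature: (model_file, params_file, format).
  option.SetModelPath("model.pt", "", fastdeploy::ModelFormat::TORCHSCRIPT);
  option.UsePorosBackend();
  option.UseCpu();
  option.is_dynamic = false;

  // One group of prewarm inputs for compilation.
  std::vector<std::vector<fastdeploy::FDTensor>> prewarm_tensors(1);
  prewarm_tensors[0].resize(1);
  prewarm_tensors[0][0].Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32);

  fastdeploy::Runtime runtime;
  if (!runtime.Compile(prewarm_tensors, option)) {
    return false;
  }

  std::vector<fastdeploy::FDTensor> inputs(1);
  inputs[0].Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32);
  std::vector<fastdeploy::FDTensor> outputs(runtime.NumOutputs());
  return runtime.Infer(inputs, &outputs);
}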


@@ -38,6 +38,7 @@ enum Backend {
ORT, ///< ONNX Runtime, support Paddle/ONNX format model, CPU / Nvidia GPU
TRT, ///< TensorRT, support Paddle/ONNX format model, Nvidia GPU only
PDINFER, ///< Paddle Inference, support Paddle format model, CPU / Nvidia GPU
POROS, ///< Poros, support TorchScript format model, CPU / Nvidia GPU
OPENVINO, ///< Intel OpenVINO, support Paddle/ONNX format, CPU only
LITE, ///< Paddle Lite, support Paddle format model, ARM CPU only
};
@@ -47,6 +48,7 @@ enum ModelFormat {
AUTOREC, ///< Auto recognize the model format by model file name
PADDLE, ///< Model with paddlepaddle format
ONNX, ///< Model with ONNX format
TORCHSCRIPT, ///< Model with TorchScript format
};
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
@@ -117,6 +119,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
/// Set TensorRT as inference backend, only support GPU
void UseTrtBackend();
/// Set Poros backend as inference backend, support CPU/GPU
void UsePorosBackend();
/// Set OpenVINO as inference backend, only support CPU
void UseOpenVINOBackend();
@@ -243,6 +248,13 @@ struct FASTDEPLOY_DECL RuntimeOption {
size_t trt_max_batch_size = 32;
size_t trt_max_workspace_size = 1 << 30;
// ======Only for Poros Backend=======
bool is_dynamic = false;
bool long_to_int = true;
bool use_nvidia_tf32 = false;
int unconst_ops_thres = -1;
std::string poros_file = "";
std::string model_file = ""; // Path of model file
std::string params_file = ""; // Path of parameters file, can be empty
ModelFormat model_format = ModelFormat::AUTOREC; // format of input model
@@ -270,6 +282,15 @@ struct FASTDEPLOY_DECL Runtime {
bool Infer(std::vector<FDTensor>& input_tensors,
std::vector<FDTensor>* output_tensors);
/** \brief Compile TorchScript Module, only for Poros backend
*
* \param[in] prewarm_tensors Prewarm datas for compile
* \param[in] _option Runtime option
* \return true if compile successed, otherwise false
*/
bool Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
const RuntimeOption& _option);
/** \brief Get number of inputs
*/
int NumInputs() { return backend_->NumInputs(); }