mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-22 08:09:28 +08:00
Support Poros Backend (#188)
* Add poros backend * Add torch lib * Add python3 lib * set c++ 14 for poros * fixed bugs * fixed grammer bugs * fixed grammer bugs * fixed code bugs * fixed code bugs * fixed CreatePorosValue bug * Add AtType2String for Log * fixed trt_option * fixed poros.cmake path * fixed grammer bug * fixed grammer bug * fixed ambiguous reference * fixed ambiguous reference * fixed reference error * fixed include files * rm ENABLE_TRT_BACKEND in poros * update CMakeLists.txt * fixed CMakeLists.txt * Add libtorch.so in CMakeLists.txt * Fixed CMakeLists.txt * Fixed CMakeLists.txt * Fixed copy bug * Fixed copy bug * Fixed copy bug * Fixed Cmake * Fixed Cmake * debug * debug * debug * debug * debug * debug * debug utils * debug utils * copy to cpu * rm log info * test share mem * test share mem * test share mem * test multi outputs * test multi outputs * test multi outputs * test multi outputs * test multi outputs * test multi outputs * test multi outputs * time cost * time cost * fixed bug * time collect * mem copy * mem copy * rm time log * rm share mem * fixed multi inputs bug * add set_input_dtypes func * add SetInputDtypes * fixed bug * fixed bug * fixed prewarm data order * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * fixed bug * Add compile func * Add compile func * Add compile func * Add is_dynamic option * Add is_dynamic option * Add is_dynamic option * Add is_dynamic option * rm infer log * add cuda11.6 poros lib * fixed bug * fixed bug * fixed multi outputs * fixed multi outputs * fixed multi outputs * fixed multi outputs * fixed multi outputs * fixed multi outputs * fixed multi outputs * fixed multi outputs * fixed multi outputs * fixed multi outputs * fixed multi outputs * rm logs * test * test * test * add test log * add test log * add test log * add test log * support cpu * support cpu * support cpu * support cpu * support member variable definition * rm useless log * fixed name * resolve conflict * resolve conflict * 
resolve conflict * fixed cmake * add GetInputInfos&GetOutputInfos * add GetInputInfos&GetOutputInfos * fixed bug * fixed runtime.py * add compile func * add np * deal with comments * rm to_inter func * add property
This commit is contained in:
167
fastdeploy/backends/poros/common/compile.h
Executable file
167
fastdeploy/backends/poros/common/compile.h
Executable file
@@ -0,0 +1,167 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
|
||||
#include "torch/script.h"
|
||||
#include "iengine.h"
|
||||
#include "poros_module.h"
|
||||
|
||||
namespace baidu {
|
||||
namespace mirana {
|
||||
namespace poros {
|
||||
|
||||
/**
|
||||
* @brief compile graph
|
||||
*
|
||||
* @param [in] module : original module
|
||||
* @param [in] input_ivalues : prewarm datas
|
||||
* @param [in] options : Inference options
|
||||
* @return porosmodule
|
||||
* @retval !nullptr => succeed nullptr => failed
|
||||
**/
|
||||
std::unique_ptr<PorosModule> Compile(const torch::jit::Module& module,
|
||||
const std::vector<std::vector<c10::IValue> >& prewarm_datas,
|
||||
const PorosOptions& options);
|
||||
|
||||
class Compiler {
|
||||
public:
|
||||
typedef std::unordered_map<const torch::jit::Node*, IEngine*> engine_map_t;
|
||||
typedef std::vector<std::vector<c10::IValue> > ivalue_vec_t;
|
||||
|
||||
Compiler() : _origin_module(NULL) {}
|
||||
~Compiler();
|
||||
|
||||
/**
|
||||
* @brief initial Compiler
|
||||
*
|
||||
* @param [in] options : poros options
|
||||
* @return int
|
||||
* @retval 0 => succeed <0 => failed
|
||||
**/
|
||||
int init(const PorosOptions& options);
|
||||
|
||||
/**
|
||||
* @brief compile whole graph
|
||||
*
|
||||
* @param [in] origin_module
|
||||
* @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
|
||||
* @param [out] optimized_module : optimized graph
|
||||
* @return int
|
||||
* @retval 0 => succeed <0 => failed
|
||||
**/
|
||||
int compile(const torch::jit::Module& origin_module,
|
||||
const ivalue_vec_t& prewarm_datas,
|
||||
torch::jit::Module* optimized_module);
|
||||
|
||||
private:
|
||||
|
||||
/**
|
||||
* @brief preprocess this calculation graph
|
||||
*
|
||||
* @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
|
||||
* @param [out] graph : preprcessed graph
|
||||
* @return int
|
||||
* @retval 0 => succeed <0 => failed
|
||||
**/
|
||||
int preprocess_graph(const ivalue_vec_t& prewarm_datas, std::shared_ptr<torch::jit::Graph>& graph);
|
||||
|
||||
/**
|
||||
* @brief segement this calculation graph
|
||||
*
|
||||
* @param [in/out] graph
|
||||
* @return int
|
||||
* @retval 0 => succeed <0 => failed
|
||||
**/
|
||||
int segment_graph(std::shared_ptr<torch::jit::Graph>& graph);
|
||||
|
||||
// Split subgraph(block)
|
||||
// The divided subgraph, as a subgraph, is associated with the block
|
||||
int segment_block(torch::jit::Block& block, IEngine* engine, int current_depth);
|
||||
|
||||
// Subgraph optimization
|
||||
/**
|
||||
* @brief Subgraph optimization
|
||||
*
|
||||
* @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
|
||||
* @param [in] opt_graph : ivalue_vec_t, vector of IValue
|
||||
* @param [out] optimized_module : optimized graph
|
||||
* @return int
|
||||
* @retval 0 => succeed <0 => failed
|
||||
**/
|
||||
int optimize_subgraph(const ivalue_vec_t& prewarm_datas,
|
||||
const std::shared_ptr<torch::jit::Graph>& opt_graph,
|
||||
torch::jit::Module* optimized_module);
|
||||
|
||||
// Subgraph optimization(block)
|
||||
int optimize_subblock(torch::jit::Block* block,
|
||||
torch::jit::Module* optimized_module);
|
||||
|
||||
/**
|
||||
* @brief Compile the subgraph into a new graph based on the engine
|
||||
*
|
||||
* @param [in] engine : The engine used by the subgraph
|
||||
* @param [in] subgraph_node : Subgraph node
|
||||
* @return [out] module : Transformed model
|
||||
* @retval 0 => succeed <0 => failed
|
||||
**/
|
||||
int transform(IEngine* engine, torch::jit::Node& subgraph_node,
|
||||
torch::jit::Module& module);
|
||||
|
||||
/**
|
||||
* @brief Select engine based on subgraph and options
|
||||
*
|
||||
* @param [in] node : Jit Node
|
||||
* @return int
|
||||
* @retval 0 => succeed <0 => failed
|
||||
**/
|
||||
IEngine* select_engine(const torch::jit::Node* n);
|
||||
|
||||
/**
|
||||
* @brief destory
|
||||
*
|
||||
* @return void
|
||||
**/
|
||||
void close();
|
||||
|
||||
private:
|
||||
int _max_segment_depth{5}; // Maximum subgraph segmentation depth
|
||||
ivalue_vec_t _prewarm_datas; // Prewarm datas
|
||||
PorosOptions _options;
|
||||
engine_map_t _engine_map; // The engine used to record the subgraph
|
||||
const torch::jit::Module* _origin_module; // Origin_module
|
||||
std::atomic<int> _engine_index = {0}; // Record engine index
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief compile graph, internal use
|
||||
*
|
||||
* @param [in] module : Origin module
|
||||
* @param [in] input_ivalues : Prewarm datas
|
||||
* @param [in] options : Inference options
|
||||
* @return optimized_module
|
||||
* @retval !nullptr => succeed nullptr => failed
|
||||
**/
|
||||
std::unique_ptr<torch::jit::Module> CompileGraph(const torch::jit::Module& module,
|
||||
const std::vector<std::vector<c10::IValue> >& prewarm_datas,
|
||||
const PorosOptions& options);
|
||||
|
||||
} // namespace poros
|
||||
} // namespace mirana
|
||||
} // namespace baidu
|
84
fastdeploy/backends/poros/common/iengine.h
Executable file
84
fastdeploy/backends/poros/common/iengine.h
Executable file
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
//from pytorch
|
||||
#include "torch/script.h"
|
||||
#include "torch/csrc/jit/ir/ir.h"
|
||||
#include "ATen/core/interned_strings.h"
|
||||
|
||||
#include "plugin_create.h"
|
||||
|
||||
namespace baidu {
|
||||
namespace mirana {
|
||||
namespace poros {
|
||||
|
||||
/**
|
||||
* the base engine class
|
||||
* every registered engine should inherit from this IEngine
|
||||
**/
|
||||
|
||||
struct PorosGraph {
|
||||
torch::jit::Graph* graph = NULL;
|
||||
torch::jit::Node* node = NULL;
|
||||
};
|
||||
|
||||
typedef uint64_t EngineID;
|
||||
|
||||
class IEngine : public IPlugin, public torch::CustomClassHolder{
|
||||
public:
|
||||
virtual ~IEngine() {}
|
||||
|
||||
/**
|
||||
* @brief init, initialization must be successful if the init is successful
|
||||
* @return int
|
||||
* @retval 0 => success, <0 => fail
|
||||
**/
|
||||
virtual int init() = 0;
|
||||
|
||||
/**
|
||||
* @brief During compilation, the subgraph is converted into the graph structure of the corresponding engine and stored inside the engine, so that the execute_engine at runtime can be called
|
||||
* @param [in] sub_graph : subgraph
|
||||
* @return [res]int
|
||||
* @retval 0 => success, <0 => fail
|
||||
**/
|
||||
virtual int transform(const PorosGraph& sub_graph) = 0;
|
||||
|
||||
/**
|
||||
* @brief Subgraph execution period logic
|
||||
* @param [in] inputs : input tensor
|
||||
* @return [res] output tensor
|
||||
**/
|
||||
virtual std::vector<at::Tensor> excute_engine(const std::vector<at::Tensor>& inputs) = 0;
|
||||
|
||||
virtual void register_module_attribute(const std::string& name, torch::jit::Module& module) = 0;
|
||||
|
||||
// Logo
|
||||
virtual const std::string who_am_i() = 0;
|
||||
|
||||
// Whether the node is supported by the current engine
|
||||
bool is_node_supported(const torch::jit::Node* node);
|
||||
|
||||
public:
|
||||
std::pair<uint64_t, uint64_t> _num_io; // Number of input/output parameters
|
||||
EngineID _id;
|
||||
|
||||
};
|
||||
|
||||
} // namespace poros
|
||||
} // namespace mirana
|
||||
} // namespace baidu
|
65
fastdeploy/backends/poros/common/plugin_create.h
Executable file
65
fastdeploy/backends/poros/common/plugin_create.h
Executable file
@@ -0,0 +1,65 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
|
||||
namespace baidu {
|
||||
namespace mirana {
|
||||
namespace poros {
|
||||
|
||||
// Polymorphic base type produced by the plugin creator functions in this header.
class IPlugin {
public:
    virtual ~IPlugin() = default;

    // Implemented by every concrete plugin; returns an identifying string.
    virtual const std::string who_am_i() = 0;
};
|
||||
|
||||
typedef IPlugin* (*plugin_creator_t)();
|
||||
typedef std::unordered_map<std::string, plugin_creator_t> plugin_creator_map_t;
|
||||
|
||||
IPlugin* create_plugin(const std::string& plugin_name);
|
||||
IPlugin* create_plugin(const std::string& plugin_name, const plugin_creator_map_t& plugin_creator_map);
|
||||
|
||||
void create_all_plugins(const plugin_creator_map_t& plugin_creator_map,
|
||||
std::unordered_map<std::string, IPlugin*>& plugin_m);
|
||||
//void create_all_plugins(std::unordered_map<std::string, IPlugin*>& plugin_m);
|
||||
|
||||
template <typename PluginType>
|
||||
IPlugin* default_plugin_creator() {
|
||||
return new (std::nothrow)PluginType;
|
||||
}
|
||||
|
||||
void register_plugin_creator(const std::string& plugin_name, plugin_creator_t creator);
|
||||
void register_plugin_creator(const std::string& plugin_name,
|
||||
plugin_creator_t creator, plugin_creator_map_t& plugin_creator_map);
|
||||
|
||||
template <typename PluginType>
|
||||
void register_plugin_class(const std::string& plugin_name) {
|
||||
return register_plugin_creator(plugin_name, default_plugin_creator<PluginType>);
|
||||
}
|
||||
|
||||
// This version is recommended
|
||||
template <typename PluginType>
|
||||
void register_plugin_class(const std::string& plugin_name, plugin_creator_map_t& plugin_creator_map) {
|
||||
return register_plugin_creator(plugin_name, default_plugin_creator<PluginType>, plugin_creator_map);
|
||||
}
|
||||
|
||||
}//poros
|
||||
}//mirana
|
||||
}//baidu
|
||||
|
||||
|
||||
/* vim: set ts=4 sw=4 sts=4 tw=100 */
|
67
fastdeploy/backends/poros/common/poros_module.h
Executable file
67
fastdeploy/backends/poros/common/poros_module.h
Executable file
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "torch/script.h"
|
||||
#include "torch/csrc/jit/jit_log.h"
|
||||
// #include "ATen/Context.h"
|
||||
|
||||
namespace baidu {
|
||||
namespace mirana {
|
||||
namespace poros {
|
||||
|
||||
// Target device kinds, stored in a single signed byte.
enum Device : int8_t {
    GPU = 0,
    CPU = 1,
    XPU = 2,
    UNKNOW = 3
};
|
||||
|
||||
struct PorosOptions {
|
||||
Device device = GPU;
|
||||
bool debug = false;
|
||||
bool use_fp16 = false;
|
||||
bool is_dynamic = false;
|
||||
bool long_to_int = true;
|
||||
uint64_t max_workspace_size = 1ULL << 30;
|
||||
int32_t device_id = -1;
|
||||
int32_t unconst_ops_thres = -1;
|
||||
bool use_nvidia_tf32 = false;
|
||||
};
|
||||
|
||||
class PorosModule : public torch::jit::Module {
|
||||
public:
|
||||
PorosModule(torch::jit::Module module) : torch::jit::Module(module) {
|
||||
}
|
||||
~PorosModule() = default;
|
||||
|
||||
void to_device(Device device){
|
||||
_options.device = device;
|
||||
}
|
||||
|
||||
//c10::IValue forward(std::vector<c10::IValue> inputs);
|
||||
//void save(const std::string& filename);
|
||||
public:
|
||||
PorosOptions _options;
|
||||
|
||||
};
|
||||
|
||||
//via porosmodule.save
|
||||
std::unique_ptr<PorosModule> Load(const std::string& filename, const PorosOptions& options);
|
||||
|
||||
} // namespace poros
|
||||
} // namespace mirana
|
||||
} // namespace baidu
|
240
fastdeploy/backends/poros/poros_backend.cc
Executable file
240
fastdeploy/backends/poros/poros_backend.cc
Executable file
@@ -0,0 +1,240 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/backends/poros/poros_backend.h"
|
||||
#include <sys/time.h>
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
// Eager mode can't obtain input information before inference, so a
// default-constructed TensorInfo is returned and `index` is ignored.
TensorInfo PorosBackend::GetInputInfo(int index) {
  TensorInfo placeholder;
  return placeholder;
}
|
||||
|
||||
// Eager mode can't obtain output information before inference, so a
// default-constructed TensorInfo is returned and `index` is ignored.
TensorInfo PorosBackend::GetOutputInfo(int index) {
  TensorInfo placeholder;
  return placeholder;
}
|
||||
|
||||
// Eager mode can't obtain inputs information before inference, so the
// returned list is always empty.
std::vector<TensorInfo> PorosBackend::GetInputInfos() {
  std::vector<TensorInfo> no_infos;
  return no_infos;
}
|
||||
|
||||
// Eager mode can't obtain outputs information before inference, so the
// returned list is always empty.
std::vector<TensorInfo> PorosBackend::GetOutputInfos() {
  std::vector<TensorInfo> no_infos;
  return no_infos;
}
|
||||
|
||||
void PorosBackend::BuildOption(const PorosBackendOption& option) {
|
||||
_options.device = option.use_gpu ? baidu::mirana::poros::Device::GPU
|
||||
: baidu::mirana::poros::Device::CPU;
|
||||
_options.long_to_int = option.long_to_int;
|
||||
_options.use_nvidia_tf32 = option.use_nvidia_tf32;
|
||||
_options.device_id = option.gpu_id;
|
||||
_options.unconst_ops_thres = option.unconst_ops_thres;
|
||||
_options.is_dynamic = option.is_dynamic;
|
||||
_options.max_workspace_size = option.max_workspace_size;
|
||||
_options.use_fp16 = option.enable_fp16;
|
||||
return;
|
||||
}
|
||||
|
||||
bool PorosBackend::Compile(const std::string& model_file,
|
||||
std::vector<std::vector<FDTensor>>& prewarm_tensors,
|
||||
const PorosBackendOption& option) {
|
||||
if (initialized_) {
|
||||
FDERROR << "PorosBackend is already initlized, cannot initialize again."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
BuildOption(option);
|
||||
torch::jit::Module mod;
|
||||
mod = torch::jit::load(model_file);
|
||||
mod.eval();
|
||||
if (option.use_gpu) {
|
||||
mod.to(at::kCUDA);
|
||||
} else {
|
||||
mod.to(at::kCPU);
|
||||
}
|
||||
// get inputs_nums and outputs_nums
|
||||
auto graph = mod.get_method("forward").graph();
|
||||
auto inputs = graph->inputs();
|
||||
// remove self node
|
||||
_numinputs = inputs.size() - 1;
|
||||
// FDTensor to at::Tensor
|
||||
std::vector<std::vector<c10::IValue>> prewarm_datas;
|
||||
bool is_backend_cuda = option.use_gpu ? true : false;
|
||||
for (size_t i = 0; i < prewarm_tensors.size(); ++i) {
|
||||
std::vector<c10::IValue> prewarm_data;
|
||||
for (size_t j = 0; j < prewarm_tensors[i].size(); ++j) {
|
||||
auto tensor = CreatePorosValue(prewarm_tensors[i][j], is_backend_cuda);
|
||||
prewarm_data.push_back(tensor);
|
||||
}
|
||||
prewarm_datas.push_back(prewarm_data);
|
||||
}
|
||||
// get outputs nums
|
||||
auto temp_result = mod.forward(prewarm_datas[0]);
|
||||
size_t outputs_nums = 0;
|
||||
if (temp_result.isTensor()) {
|
||||
outputs_nums += 1;
|
||||
} else if (temp_result.isTuple()) {
|
||||
auto temp_result_tuple = temp_result.toTuple();
|
||||
for (size_t i = 0; i < temp_result_tuple->elements().size(); ++i) {
|
||||
auto poros_tensor = temp_result_tuple->elements()[i];
|
||||
if (poros_tensor.isTensor()) {
|
||||
outputs_nums += 1;
|
||||
} else if (poros_tensor.isList()) {
|
||||
auto poros_tensor_list = poros_tensor.toList();
|
||||
outputs_nums += poros_tensor_list.size();
|
||||
} else if (poros_tensor.isTuple()) {
|
||||
auto poros_tensor_tuple = poros_tensor.toTuple();
|
||||
outputs_nums += poros_tensor_tuple->elements().size();
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
_numoutputs = outputs_nums;
|
||||
_poros_module = baidu::mirana::poros::Compile(mod, prewarm_datas, _options);
|
||||
if (_poros_module == nullptr) {
|
||||
FDERROR << "PorosBackend initlize Failed, try initialize again."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PorosBackend::InitFromTorchScript(const std::string& model_file,
|
||||
const PorosBackendOption& option) {
|
||||
if (initialized_) {
|
||||
FDERROR << "PorosBackend is already initlized, cannot initialize again."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if (option.poros_file != "") {
|
||||
std::ifstream fin(option.poros_file, std::ios::binary | std::ios::in);
|
||||
if (fin) {
|
||||
FDINFO << "Detect compiled Poros file in " << option.poros_file
|
||||
<< ", will load it directly." << std::endl;
|
||||
fin.close();
|
||||
return InitFromPoros(option.poros_file, option);
|
||||
}
|
||||
}
|
||||
BuildOption(option);
|
||||
torch::jit::Module mod;
|
||||
mod = torch::jit::load(model_file);
|
||||
mod.eval();
|
||||
if (option.use_gpu) {
|
||||
mod.to(at::kCUDA);
|
||||
} else {
|
||||
mod.to(at::kCPU);
|
||||
}
|
||||
// get inputs_nums and outputs_nums
|
||||
auto graph = mod.get_method("forward").graph();
|
||||
auto inputs = graph->inputs();
|
||||
// remove self node
|
||||
_numinputs = inputs.size() - 1;
|
||||
auto outputs = graph->outputs();
|
||||
_numoutputs = outputs.size();
|
||||
_poros_module = baidu::mirana::poros::Compile(mod, _prewarm_datas, _options);
|
||||
if (_poros_module == nullptr) {
|
||||
FDERROR << "PorosBackend initlize Failed, try initialize again."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PorosBackend::InitFromPoros(const std::string& model_file,
|
||||
const PorosBackendOption& option) {
|
||||
if (initialized_) {
|
||||
FDERROR << "PorosBackend is already initlized, cannot initialize again."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
BuildOption(option);
|
||||
_poros_module = baidu::mirana::poros::Load(model_file, _options);
|
||||
if (_poros_module == nullptr) {
|
||||
FDERROR << "PorosBackend initlize Failed, try initialize again."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
// get inputs_nums and outputs_nums
|
||||
auto graph = _poros_module->get_method("forward").graph();
|
||||
auto inputs = graph->inputs();
|
||||
// remove self node
|
||||
_numinputs = inputs.size() - 1;
|
||||
auto outputs = graph->outputs();
|
||||
_numoutputs = outputs.size();
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PorosBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
// Convert FD Tensor to PyTorch Tensor
|
||||
std::vector<torch::jit::IValue> poros_inputs;
|
||||
bool is_backend_cuda =
|
||||
_options.device == baidu::mirana::poros::Device::GPU ? true : false;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
poros_inputs.push_back(CreatePorosValue(inputs[i], is_backend_cuda));
|
||||
}
|
||||
// Infer
|
||||
auto poros_outputs = _poros_module->forward(poros_inputs);
|
||||
// Convert PyTorch Tensor to FD Tensor
|
||||
if (poros_outputs.isTensor()) {
|
||||
CopyTensorToCpu(poros_outputs.toTensor(), &((*outputs)[0]),
|
||||
is_backend_cuda);
|
||||
} else if (poros_outputs.isTuple()) {
|
||||
// deal with multi outputs
|
||||
auto poros_outputs_tuple = poros_outputs.toTuple();
|
||||
size_t index = 0;
|
||||
for (size_t i = 0; i < poros_outputs_tuple->elements().size(); ++i) {
|
||||
auto poros_tensor = poros_outputs_tuple->elements()[i];
|
||||
if (poros_tensor.isTensor()) {
|
||||
CopyTensorToCpu(poros_tensor.toTensor(), &((*outputs)[index]),
|
||||
is_backend_cuda);
|
||||
index += 1;
|
||||
} else if (poros_tensor.isList()) {
|
||||
auto poros_tensor_list = poros_tensor.toList();
|
||||
for (const auto list_idx : c10::irange(0, poros_tensor_list.size())) {
|
||||
const auto& elt = poros_tensor_list.get(list_idx);
|
||||
CopyTensorToCpu(elt.toTensor(), &((*outputs)[index]),
|
||||
is_backend_cuda);
|
||||
index += 1;
|
||||
}
|
||||
} else if (poros_tensor.isTuple()) {
|
||||
auto poros_tensor_tuple = poros_tensor.toTuple();
|
||||
for (size_t j = 0; j < poros_tensor_tuple->elements().size(); ++j) {
|
||||
CopyTensorToCpu(poros_tensor_tuple->elements()[j].toTensor(),
|
||||
&((*outputs)[index]), is_backend_cuda);
|
||||
index += 1;
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
FDERROR << "Convert to FDTensor Failed!!!!!" << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
107
fastdeploy/backends/poros/poros_backend.h
Executable file
107
fastdeploy/backends/poros/poros_backend.h
Executable file
@@ -0,0 +1,107 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "fastdeploy/backends/backend.h"
|
||||
|
||||
#include "fastdeploy/backends/poros/common/compile.h"
|
||||
#include "fastdeploy/backends/poros/common/poros_module.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
struct PorosBackendOption {
|
||||
#ifdef WITH_GPU
|
||||
bool use_gpu = true;
|
||||
#else
|
||||
bool use_gpu = false;
|
||||
#endif
|
||||
int gpu_id = 0;
|
||||
bool long_to_int = true;
|
||||
// There is calculation precision in tf32 mode on A10, it can bring some
|
||||
// performance improvement, but there may be diff
|
||||
bool use_nvidia_tf32 = false;
|
||||
// Threshold for the number of non-const ops
|
||||
int32_t unconst_ops_thres = -1;
|
||||
std::string poros_file = "";
|
||||
std::vector<FDDataType> prewarm_datatypes = {FDDataType::FP32};
|
||||
// TRT options
|
||||
bool enable_fp16 = false;
|
||||
bool enable_int8 = false;
|
||||
bool is_dynamic = false;
|
||||
size_t max_batch_size = 32;
|
||||
size_t max_workspace_size = 1 << 30;
|
||||
};
|
||||
|
||||
// Convert data type from fastdeploy to poros
|
||||
at::ScalarType GetPorosDtype(const FDDataType& fd_dtype);
|
||||
|
||||
// Convert data type from poros to fastdeploy
|
||||
FDDataType GetFdDtype(const at::ScalarType& dtype);
|
||||
|
||||
// at::ScalarType to std::string for FDERROR
|
||||
std::string AtType2String(const at::ScalarType& dtype);
|
||||
|
||||
// Create at::Tensor
|
||||
// is_backend_cuda specify if Poros use GPU Device
|
||||
// While is_backend_cuda = true, and tensor.device = Device::GPU
|
||||
at::Tensor CreatePorosValue(FDTensor& tensor, bool is_backend_cuda = false);
|
||||
|
||||
// Copy memory data from at::Tensor to fastdeploy::FDTensor
|
||||
void CopyTensorToCpu(const at::Tensor& tensor, FDTensor* fd_tensor,
|
||||
bool is_backend_cuda = false);
|
||||
|
||||
class PorosBackend : public BaseBackend {
|
||||
public:
|
||||
PorosBackend() {}
|
||||
virtual ~PorosBackend() = default;
|
||||
|
||||
void BuildOption(const PorosBackendOption& option);
|
||||
|
||||
bool InitFromTorchScript(
|
||||
const std::string& model_file,
|
||||
const PorosBackendOption& option = PorosBackendOption());
|
||||
|
||||
bool InitFromPoros(const std::string& model_file,
|
||||
const PorosBackendOption& option = PorosBackendOption());
|
||||
|
||||
bool Compile(const std::string& model_file,
|
||||
std::vector<std::vector<FDTensor>>& prewarm_tensors,
|
||||
const PorosBackendOption& option = PorosBackendOption());
|
||||
|
||||
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs);
|
||||
|
||||
int NumInputs() const { return _numinputs; }
|
||||
|
||||
int NumOutputs() const { return _numoutputs; }
|
||||
|
||||
TensorInfo GetInputInfo(int index) override;
|
||||
TensorInfo GetOutputInfo(int index) override;
|
||||
std::vector<TensorInfo> GetInputInfos() override;
|
||||
std::vector<TensorInfo> GetOutputInfos() override;
|
||||
|
||||
private:
|
||||
baidu::mirana::poros::PorosOptions _options;
|
||||
std::unique_ptr<baidu::mirana::poros::PorosModule> _poros_module;
|
||||
std::vector<std::vector<c10::IValue>> _prewarm_datas;
|
||||
int _numinputs = 1;
|
||||
int _numoutputs = 1;
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
186
fastdeploy/backends/poros/utils.cc
Normal file
186
fastdeploy/backends/poros/utils.cc
Normal file
@@ -0,0 +1,186 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/backends/poros/poros_backend.h"
|
||||
|
||||
#ifdef WITH_GPU
|
||||
#include <cuda_runtime_api.h>
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
// Return a printable name for `dtype`; types outside the eight listed below
// map to "at::UNKNOWN".
std::string AtType2String(const at::ScalarType& dtype) {
  switch (dtype) {
    case at::kByte:
      return "at::kByte";
    case at::kChar:
      return "at::kChar";
    case at::kShort:
      return "at::kShort";
    case at::kInt:
      return "at::kInt";
    case at::kLong:
      return "at::kLong";
    case at::kHalf:
      return "at::kHalf";
    case at::kFloat:
      return "at::kFloat";
    case at::kDouble:
      return "at::kDouble";
    default:
      return "at::UNKNOWN";
  }
}
|
||||
|
||||
// Map a FastDeploy data type to the matching at::ScalarType.
// Unsupported types are logged with FDERROR and fall back to at::kFloat.
at::ScalarType GetPorosDtype(const FDDataType& fd_dtype) {
  if (fd_dtype == FDDataType::FP32) return at::kFloat;
  if (fd_dtype == FDDataType::FP64) return at::kDouble;
  if (fd_dtype == FDDataType::INT32) return at::kInt;
  if (fd_dtype == FDDataType::INT64) return at::kLong;

  FDERROR << "Unrecognized fastdeply data type:" << Str(fd_dtype) << "."
          << std::endl;
  return at::kFloat;
}
|
||||
|
||||
// Map an at::ScalarType to the matching FastDeploy data type.
// Unsupported types are logged with FDERROR and fall back to FDDataType::FP32.
FDDataType GetFdDtype(const at::ScalarType& poros_dtype) {
  if (poros_dtype == at::kFloat) return FDDataType::FP32;
  if (poros_dtype == at::kDouble) return FDDataType::FP64;
  if (poros_dtype == at::kInt) return FDDataType::INT32;
  if (poros_dtype == at::kLong) return FDDataType::INT64;

  FDERROR << "Unrecognized poros data type:" << AtType2String(poros_dtype)
          << "." << std::endl;
  return FDDataType::FP32;
}
|
||||
|
||||
// Create an at::Tensor with the shape and dtype of `tensor` and copy the
// tensor's data into it.
//
// tensor          : source FDTensor; its device must be Device::CPU or
//                   Device::GPU
// is_backend_cuda : true  -> the result is allocated on CUDA
//                   false -> the result is allocated on CPU
//
// The copy direction follows both the source device and the target device
// (host->host, host->device, device->device, device->host).
// NOTE(review): assumes tensor.Data() points to device memory when
// tensor.device == Device::GPU - confirm against FDTensor.
// Any path touching GPU memory requires a build with WITH_GPU; otherwise the
// function asserts instead of failing to compile on CPU-only builds.
at::Tensor CreatePorosValue(FDTensor& tensor, bool is_backend_cuda) {
  FDASSERT(tensor.device == Device::GPU || tensor.device == Device::CPU,
           "Only support tensor which device is CPU or GPU for PorosBackend.");
  const at::ScalarType data_type = GetPorosDtype(tensor.dtype);
  FDASSERT(data_type == at::kFloat || data_type == at::kInt ||
               data_type == at::kLong || data_type == at::kDouble,
           "Unrecognized data type while calling "
           "PorosBackend::CreatePorosValue().");

  // Allocate directly with the final dtype rather than allocating a float
  // tensor and converting its uninitialized contents.
  const at::Device target_device(is_backend_cuda ? at::kCUDA : at::kCPU);
  at::Tensor poros_value = at::empty(
      tensor.shape, at::TensorOptions().device(target_device).dtype(data_type));

  // One byte count covers all four supported element types.
  const size_t num_bytes = poros_value.nbytes();
  if (num_bytes == 0) {
    return poros_value;
  }

  void* dst = poros_value.data_ptr();
  const void* src = static_cast<const void*>(tensor.Data());
  const bool src_on_gpu = (tensor.device == Device::GPU);

  if (!is_backend_cuda && !src_on_gpu) {
    memcpy(dst, src, num_bytes);
    return poros_value;
  }

#ifdef WITH_GPU
  cudaMemcpyKind kind = cudaMemcpyHostToDevice;
  if (is_backend_cuda && src_on_gpu) {
    kind = cudaMemcpyDeviceToDevice;
  } else if (!is_backend_cuda) {
    kind = cudaMemcpyDeviceToHost;
  }
  const cudaError_t status = cudaMemcpy(dst, src, num_bytes, kind);
  FDASSERT(status == cudaSuccess,
           "cudaMemcpy failed while calling "
           "PorosBackend::CreatePorosValue().");
#else
  FDASSERT(false,
           "PorosBackend::CreatePorosValue() needs GPU memory access, but "
           "FastDeploy was built without WITH_GPU.");
#endif
  return poros_value;
}
|
||||
|
||||
/// Copy the contents of an at::Tensor into an FDTensor.
///
/// \param[in]  tensor          Source tensor.
/// \param[out] fd_tensor       Destination; resized to the shape of `tensor`
///                             and to the data type given by GetFdDtype().
/// \param[in]  is_backend_cuda When true the data is fetched with a
///                             device-to-host cudaMemcpy, otherwise memcpy.
///
/// Only kFloat, kInt, kLong and kDouble are copied. For any other scalar type
/// the destination is still resized (GetFdDtype() logs the error) but no data
/// is written. Aborts through FDASSERT when the CUDA copy fails.
void CopyTensorToCpu(const at::Tensor& tensor, FDTensor* fd_tensor,
                     bool is_backend_cuda) {
  const auto data_type = tensor.scalar_type();
  const auto sizes = tensor.sizes();
  std::vector<int64_t> shape(sizes.begin(), sizes.end());
  fd_tensor->Resize(shape, GetFdDtype(data_type));

  const bool is_supported = data_type == at::kFloat ||
                            data_type == at::kInt ||
                            data_type == at::kLong ||
                            data_type == at::kDouble;
  if (!is_supported) {
    // The destination was resized with the FP32 fallback; copying bytes of a
    // different element width into it would be wrong, so leave it untouched.
    return;
  }

  // A raw byte copy is only valid for densely packed storage; contiguous()
  // returns the same tensor when it is already dense.
  const at::Tensor dense = tensor.contiguous();
  const size_t num_bytes =
      static_cast<size_t>(dense.numel()) * dense.element_size();
  if (is_backend_cuda) {
    const auto status = cudaMemcpy(fd_tensor->Data(), dense.data_ptr(),
                                   num_bytes, cudaMemcpyDeviceToHost);
    FDASSERT(status == cudaSuccess,
             "Failed to copy data to host while calling "
             "PorosBackend::CopyTensorToCpu().");
  } else {
    memcpy(fd_tensor->Data(), dense.data_ptr(), num_bytes);
  }
}
|
||||
|
||||
} // namespace fastdeploy
|
4
fastdeploy/core/config.h.in
Normal file → Executable file
4
fastdeploy/core/config.h.in
Normal file → Executable file
@@ -29,6 +29,10 @@
|
||||
#cmakedefine ENABLE_PADDLE_BACKEND
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_POROS_BACKEND
|
||||
#cmakedefine ENABLE_POROS_BACKEND
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_OPENVINO_BACKEND
|
||||
#cmakedefine ENABLE_OPENVINO_BACKEND
|
||||
#endif
|
||||
|
@@ -24,6 +24,7 @@ void BindRuntime(pybind11::module& m) {
|
||||
.def("use_cpu", &RuntimeOption::UseCpu)
|
||||
.def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
|
||||
.def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
|
||||
.def("use_poros_backend", &RuntimeOption::UsePorosBackend)
|
||||
.def("use_ort_backend", &RuntimeOption::UseOrtBackend)
|
||||
.def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel)
|
||||
.def("use_trt_backend", &RuntimeOption::UseTrtBackend)
|
||||
@@ -62,7 +63,12 @@ void BindRuntime(pybind11::module& m) {
|
||||
.def_readwrite("trt_enable_int8", &RuntimeOption::trt_enable_int8)
|
||||
.def_readwrite("trt_max_batch_size", &RuntimeOption::trt_max_batch_size)
|
||||
.def_readwrite("trt_max_workspace_size",
|
||||
&RuntimeOption::trt_max_workspace_size);
|
||||
&RuntimeOption::trt_max_workspace_size)
|
||||
.def_readwrite("is_dynamic", &RuntimeOption::is_dynamic)
|
||||
.def_readwrite("long_to_int", &RuntimeOption::long_to_int)
|
||||
.def_readwrite("use_nvidia_tf32", &RuntimeOption::use_nvidia_tf32)
|
||||
.def_readwrite("unconst_ops_thres", &RuntimeOption::unconst_ops_thres)
|
||||
.def_readwrite("poros_file", &RuntimeOption::poros_file);
|
||||
|
||||
pybind11::class_<TensorInfo>(m, "TensorInfo")
|
||||
.def_readwrite("name", &TensorInfo::name)
|
||||
@@ -72,6 +78,30 @@ void BindRuntime(pybind11::module& m) {
|
||||
pybind11::class_<Runtime>(m, "Runtime")
|
||||
.def(pybind11::init())
|
||||
.def("init", &Runtime::Init)
|
||||
.def("compile",
|
||||
[](Runtime& self,
|
||||
std::vector<std::vector<pybind11::array>>& warm_datas,
|
||||
const RuntimeOption& _option) {
|
||||
size_t rows = warm_datas.size();
|
||||
size_t columns = warm_datas[0].size();
|
||||
std::vector<std::vector<FDTensor>> warm_tensors(
|
||||
rows, std::vector<FDTensor>(columns));
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
for (size_t j = 0; j < columns; ++j) {
|
||||
auto dtype =
|
||||
NumpyDataTypeToFDDataType(warm_datas[i][j].dtype());
|
||||
std::vector<int64_t> data_shape;
|
||||
data_shape.insert(
|
||||
data_shape.begin(), warm_datas[i][j].shape(),
|
||||
warm_datas[i][j].shape() + warm_datas[i][j].ndim());
|
||||
warm_tensors[i][j].Resize(data_shape, dtype);
|
||||
memcpy(warm_tensors[i][j].MutableData(),
|
||||
warm_datas[i][j].mutable_data(),
|
||||
warm_datas[i][j].nbytes());
|
||||
}
|
||||
}
|
||||
return self.Compile(warm_tensors, _option);
|
||||
})
|
||||
.def("infer",
|
||||
[](Runtime& self, std::vector<FDTensor>& inputs) {
|
||||
std::vector<FDTensor> outputs(self.NumOutputs());
|
||||
@@ -121,11 +151,13 @@ void BindRuntime(pybind11::module& m) {
|
||||
.value("UNKOWN", Backend::UNKNOWN)
|
||||
.value("ORT", Backend::ORT)
|
||||
.value("TRT", Backend::TRT)
|
||||
.value("POROS", Backend::POROS)
|
||||
.value("PDINFER", Backend::PDINFER)
|
||||
.value("LITE", Backend::LITE);
|
||||
pybind11::enum_<ModelFormat>(m, "ModelFormat", pybind11::arithmetic(),
|
||||
"ModelFormat for inference.")
|
||||
.value("PADDLE", ModelFormat::PADDLE)
|
||||
.value("TORCHSCRIPT", ModelFormat::TORCHSCRIPT)
|
||||
.value("ONNX", ModelFormat::ONNX);
|
||||
pybind11::enum_<Device>(m, "Device", pybind11::arithmetic(),
|
||||
"Device for inference.")
|
||||
|
@@ -29,6 +29,10 @@
|
||||
#include "fastdeploy/backends/paddle/paddle_backend.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_POROS_BACKEND
|
||||
#include "fastdeploy/backends/poros/poros_backend.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_OPENVINO_BACKEND
|
||||
#include "fastdeploy/backends/openvino/ov_backend.h"
|
||||
#endif
|
||||
@@ -50,6 +54,9 @@ std::vector<Backend> GetAvailableBackends() {
|
||||
#ifdef ENABLE_PADDLE_BACKEND
|
||||
backends.push_back(Backend::PDINFER);
|
||||
#endif
|
||||
#ifdef ENABLE_POROS_BACKEND
|
||||
backends.push_back(Backend::POROS);
|
||||
#endif
|
||||
#ifdef ENABLE_OPENVINO_BACKEND
|
||||
backends.push_back(Backend::OPENVINO);
|
||||
#endif
|
||||
@@ -76,6 +83,8 @@ std::string Str(const Backend& b) {
|
||||
return "Backend::TRT";
|
||||
} else if (b == Backend::PDINFER) {
|
||||
return "Backend::PDINFER";
|
||||
} else if (b == Backend::POROS) {
|
||||
return "Backend::POROS";
|
||||
} else if (b == Backend::OPENVINO) {
|
||||
return "Backend::OPENVINO";
|
||||
} else if (b == Backend::LITE) {
|
||||
@@ -89,6 +98,8 @@ std::string Str(const ModelFormat& f) {
|
||||
return "ModelFormat::PADDLE";
|
||||
} else if (f == ModelFormat::ONNX) {
|
||||
return "ModelFormat::ONNX";
|
||||
} else if (f == ModelFormat::TORCHSCRIPT) {
|
||||
return "ModelFormat::TORCHSCRIPT";
|
||||
}
|
||||
return "UNKNOWN-ModelFormat";
|
||||
}
|
||||
@@ -102,6 +113,8 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) {
|
||||
out << "Backend::PDINFER";
|
||||
} else if (backend == Backend::OPENVINO) {
|
||||
out << "Backend::OPENVINO";
|
||||
} else if (backend == Backend::POROS) {
|
||||
out << "Backend::POROS";
|
||||
} else if (backend == Backend::LITE) {
|
||||
out << "Backend::LITE";
|
||||
}
|
||||
@@ -114,6 +127,8 @@ std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
|
||||
out << "ModelFormat::PADDLE";
|
||||
} else if (format == ModelFormat::ONNX) {
|
||||
out << "ModelFormat::ONNX";
|
||||
} else if (format == ModelFormat::TORCHSCRIPT) {
|
||||
out << "ModelFormat::TORCHSCRIPT";
|
||||
}
|
||||
out << "UNKNOWN-ModelFormat";
|
||||
return out;
|
||||
@@ -137,9 +152,17 @@ bool CheckModelFormat(const std::string& model_file,
|
||||
<< model_file << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else if (model_format == ModelFormat::TORCHSCRIPT) {
|
||||
if (model_file.size() < 3 ||
|
||||
model_file.substr(model_file.size() - 3, 3) != ".pt") {
|
||||
FDERROR << "With model format of ModelFormat::TORCHSCRIPT, the model file "
|
||||
"should ends with `.pt`, but now it's "
|
||||
<< model_file << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
FDERROR << "Only support model format with frontend ModelFormat::PADDLE / "
|
||||
"ModelFormat::ONNX."
|
||||
"ModelFormat::ONNX / ModelFormat::TORCHSCRIPT."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
@@ -155,6 +178,10 @@ ModelFormat GuessModelFormat(const std::string& model_file) {
|
||||
model_file.substr(model_file.size() - 5, 5) == ".onnx") {
|
||||
FDINFO << "Model Format: ONNX." << std::endl;
|
||||
return ModelFormat::ONNX;
|
||||
} else if (model_file.size() > 3 &&
|
||||
model_file.substr(model_file.size() - 3, 3) == ".pt") {
|
||||
FDINFO << "Model Format: Torchscript." << std::endl;
|
||||
return ModelFormat::TORCHSCRIPT;
|
||||
}
|
||||
|
||||
FDERROR << "Cannot guess which model format you are using, please set "
|
||||
@@ -173,10 +200,13 @@ void RuntimeOption::SetModelPath(const std::string& model_path,
|
||||
} else if (format == ModelFormat::ONNX) {
|
||||
model_file = model_path;
|
||||
model_format = ModelFormat::ONNX;
|
||||
} else if (format == ModelFormat::TORCHSCRIPT) {
|
||||
model_file = model_path;
|
||||
model_format = ModelFormat::TORCHSCRIPT;
|
||||
} else {
|
||||
FDASSERT(
|
||||
false,
|
||||
"The model format only can be ModelFormat::PADDLE/ModelFormat::ONNX.");
|
||||
"The model format only can be ModelFormat::PADDLE/ModelFormat::ONNX/ModelFormat::TORCHSCRIPT.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,6 +253,15 @@ void RuntimeOption::UseOrtBackend() {
|
||||
#endif
|
||||
}
|
||||
|
||||
// use poros backend
|
||||
void RuntimeOption::UsePorosBackend() {
|
||||
#ifdef ENABLE_POROS_BACKEND
|
||||
backend = Backend::POROS;
|
||||
#else
|
||||
FDASSERT(false, "The FastDeploy didn't compile with PorosBackend.");
|
||||
#endif
|
||||
}
|
||||
|
||||
void RuntimeOption::UseTrtBackend() {
|
||||
#ifdef ENABLE_TRT_BACKEND
|
||||
backend = Backend::TRT;
|
||||
@@ -324,6 +363,36 @@ void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
|
||||
trt_serialize_file = cache_file_path;
|
||||
}
|
||||
|
||||
bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
|
||||
const RuntimeOption& _option) {
|
||||
#ifdef ENABLE_POROS_BACKEND
|
||||
option = _option;
|
||||
auto poros_option = PorosBackendOption();
|
||||
poros_option.use_gpu = (option.device == Device::GPU) ? true : false;
|
||||
poros_option.gpu_id = option.device_id;
|
||||
poros_option.long_to_int = option.long_to_int;
|
||||
poros_option.use_nvidia_tf32 = option.use_nvidia_tf32;
|
||||
poros_option.unconst_ops_thres = option.unconst_ops_thres;
|
||||
poros_option.poros_file = option.poros_file;
|
||||
poros_option.is_dynamic = option.is_dynamic;
|
||||
poros_option.enable_fp16 = option.trt_enable_fp16;
|
||||
poros_option.max_batch_size = option.trt_max_batch_size;
|
||||
poros_option.max_workspace_size = option.trt_max_workspace_size;
|
||||
FDASSERT(option.model_format == ModelFormat::TORCHSCRIPT,
|
||||
"PorosBackend only support model format of ModelFormat::TORCHSCRIPT.");
|
||||
backend_ = utils::make_unique<PorosBackend>();
|
||||
auto casted_backend = dynamic_cast<PorosBackend*>(backend_.get());
|
||||
FDASSERT(
|
||||
casted_backend->Compile(option.model_file, prewarm_tensors, poros_option),
|
||||
"Load model from Torchscript failed while initliazing PorosBackend.");
|
||||
#else
|
||||
FDASSERT(false,
|
||||
"PorosBackend is not available, please compiled with "
|
||||
"ENABLE_POROS_BACKEND=ON.");
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Runtime::Init(const RuntimeOption& _option) {
|
||||
option = _option;
|
||||
if (option.model_format == ModelFormat::AUTOREC) {
|
||||
@@ -334,6 +403,8 @@ bool Runtime::Init(const RuntimeOption& _option) {
|
||||
option.backend = Backend::ORT;
|
||||
} else if (IsBackendAvailable(Backend::PDINFER)) {
|
||||
option.backend = Backend::PDINFER;
|
||||
} else if (IsBackendAvailable(Backend::POROS)) {
|
||||
option.backend = Backend::POROS;
|
||||
} else if (IsBackendAvailable(Backend::OPENVINO)) {
|
||||
option.backend = Backend::OPENVINO;
|
||||
} else {
|
||||
@@ -365,6 +436,15 @@ bool Runtime::Init(const RuntimeOption& _option) {
|
||||
CreatePaddleBackend();
|
||||
FDINFO << "Runtime initialized with Backend::PDINFER in "
|
||||
<< Str(option.device) << "." << std::endl;
|
||||
} else if (option.backend == Backend::POROS) {
|
||||
FDASSERT(option.device == Device::CPU || option.device == Device::GPU,
|
||||
"Backend::POROS only supports Device::CPU/Device::GPU.");
|
||||
FDASSERT(
|
||||
option.model_format == ModelFormat::TORCHSCRIPT,
|
||||
"Backend::POROS only supports model format of ModelFormat::TORCHSCRIPT.");
|
||||
FDINFO << "Runtime initialized with Backend::POROS in "
|
||||
<< Str(option.device) << "." << std::endl;
|
||||
return true;
|
||||
} else if (option.backend == Backend::OPENVINO) {
|
||||
FDASSERT(option.device == Device::CPU,
|
||||
"Backend::OPENVINO only supports Device::CPU");
|
||||
@@ -379,7 +459,8 @@ bool Runtime::Init(const RuntimeOption& _option) {
|
||||
<< "." << std::endl;
|
||||
} else {
|
||||
FDERROR << "Runtime only support "
|
||||
"Backend::ORT/Backend::TRT/Backend::PDINFER as backend now."
|
||||
"Backend::ORT/Backend::TRT/Backend::PDINFER/Backend::POROS as "
|
||||
"backend now."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
@@ -38,6 +38,7 @@ enum Backend {
|
||||
ORT, ///< ONNX Runtime, support Paddle/ONNX format model, CPU / Nvidia GPU
|
||||
TRT, ///< TensorRT, support Paddle/ONNX format model, Nvidia GPU only
|
||||
PDINFER, ///< Paddle Inference, support Paddle format model, CPU / Nvidia GPU
|
||||
POROS, ///< Poros, support TorchScript format model, CPU / Nvidia GPU
|
||||
OPENVINO, ///< Intel OpenVINO, support Paddle/ONNX format, CPU only
|
||||
LITE, ///< Paddle Lite, support Paddle format model, ARM CPU only
|
||||
};
|
||||
@@ -47,6 +48,7 @@ enum ModelFormat {
|
||||
AUTOREC, ///< Auto recognize the model format by model file name
|
||||
PADDLE, ///< Model with paddlepaddle format
|
||||
ONNX, ///< Model with ONNX format
|
||||
TORCHSCRIPT, ///< Model with TorchScript format
|
||||
};
|
||||
|
||||
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
|
||||
@@ -117,6 +119,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
||||
/// Set TensorRT as inference backend, only support GPU
|
||||
void UseTrtBackend();
|
||||
|
||||
/// Set Poros backend as inference backend, support CPU/GPU
|
||||
void UsePorosBackend();
|
||||
|
||||
/// Set OpenVINO as inference backend, only support CPU
|
||||
void UseOpenVINOBackend();
|
||||
|
||||
@@ -243,6 +248,13 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
||||
size_t trt_max_batch_size = 32;
|
||||
size_t trt_max_workspace_size = 1 << 30;
|
||||
|
||||
// ======Only for Poros Backend=======
|
||||
bool is_dynamic = false;
|
||||
bool long_to_int = true;
|
||||
bool use_nvidia_tf32 = false;
|
||||
int unconst_ops_thres = -1;
|
||||
std::string poros_file = "";
|
||||
|
||||
std::string model_file = ""; // Path of model file
|
||||
std::string params_file = ""; // Path of parameters file, can be empty
|
||||
ModelFormat model_format = ModelFormat::AUTOREC; // format of input model
|
||||
@@ -270,6 +282,15 @@ struct FASTDEPLOY_DECL Runtime {
|
||||
bool Infer(std::vector<FDTensor>& input_tensors,
|
||||
std::vector<FDTensor>* output_tensors);
|
||||
|
||||
/** \brief Compile TorchScript Module, only for Poros backend
|
||||
*
|
||||
* \param[in] prewarm_tensors Prewarm datas for compile
|
||||
* \param[in] _option Runtime option
|
||||
* \return true if compile successed, otherwise false
|
||||
*/
|
||||
bool Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
|
||||
const RuntimeOption& _option);
|
||||
|
||||
/** \brief Get number of inputs
|
||||
*/
|
||||
int NumInputs() { return backend_->NumInputs(); }
|
||||
|
Reference in New Issue
Block a user