[Other] Unify initialize api for lite/trt backend (#1249)

* Unify initialize api for lite/trt backend

* Unify initialize api for lite/trt backend
This commit was authored by Jason on 2023-02-08 11:16:39 +08:00 and committed via GitHub.
Parent commit: 9712f250a5 · This commit: c5b414a774
5 changed files, with 94 additions and 75 deletions.

View File

@@ -56,18 +56,39 @@ void LiteBackend::BuildOption(const LiteBackendOption& option) {
}
}
bool LiteBackend::InitFromPaddle(const std::string& model_file,
const std::string& params_file,
const LiteBackendOption& option) {
bool LiteBackend::Init(const RuntimeOption& runtime_option) {
if (initialized_) {
FDERROR << "LiteBackend is already initialized, cannot initialize again."
<< std::endl;
return false;
}
config_.set_model_file(model_file);
config_.set_param_file(params_file);
BuildOption(option);
if (runtime_option.model_format != ModelFormat::PADDLE) {
FDERROR
<< "PaddleLiteBackend only supports model format PADDLE, but now it's "
<< runtime_option.model_format << "." << std::endl;
return false;
}
if (runtime_option.device != Device::CPU &&
runtime_option.device != Device::KUNLUNXIN &&
runtime_option.device != Device::ASCEND &&
runtime_option.device != Device::TIMVX) {
FDERROR << "PaddleLiteBackend only supports "
"Device::CPU/Device::TIMVX/Device::KUNLUNXIN/Device::ASCEND, "
"but now it's "
<< runtime_option.device << "." << std::endl;
return false;
}
if (runtime_option.model_from_memory_) {
FDERROR << "PaddleLiteBackend doesn't support load model from memory, "
"please load model from disk."
<< std::endl;
return false;
}
config_.set_model_file(runtime_option.model_file);
config_.set_param_file(runtime_option.params_file);
BuildOption(runtime_option.paddle_lite_option);
predictor_ =
paddle::lite_api::CreatePaddlePredictor<paddle::lite_api::CxxConfig>(
config_);
@@ -177,7 +198,7 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
FDASSERT(false, "Unexpected data type of %d.", inputs[i].dtype);
}
}
RUNTIME_PROFILE_LOOP_BEGIN(1)
predictor_->Run();
RUNTIME_PROFILE_LOOP_END

View File

@@ -22,6 +22,7 @@
#include "paddle_api.h" // NOLINT
#include "fastdeploy/runtime/backends/backend.h"
#include "fastdeploy/runtime/runtime_option.h"
#include "fastdeploy/runtime/backends/lite/option.h"
namespace fastdeploy {
@@ -30,11 +31,8 @@ class LiteBackend : public BaseBackend {
public:
LiteBackend() {}
virtual ~LiteBackend() = default;
void BuildOption(const LiteBackendOption& option);
bool InitFromPaddle(const std::string& model_file,
const std::string& params_file,
const LiteBackendOption& option = LiteBackendOption());
bool Init(const RuntimeOption& option);
bool Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs,
@@ -50,6 +48,8 @@ class LiteBackend : public BaseBackend {
std::vector<TensorInfo> GetOutputInfos() override;
private:
void BuildOption(const LiteBackendOption& option);
void ConfigureCpu(const LiteBackendOption& option);
void ConfigureTimvx(const LiteBackendOption& option);
void ConfigureAscend(const LiteBackendOption& option);