[Other] Unify initialize api for lite/trt backend (#1249)

* Unify initialize api for lite/trt backend
Author: Jason
Date: 2023-02-08 11:16:39 +08:00 (committed by GitHub)
Parent: 9712f250a5
Commit: c5b414a774
5 changed files with 94 additions and 75 deletions
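What the change amounts to: TrtBackend (and, per the title, the Paddle Lite backend) now exposes a single Init(const RuntimeOption&) entry point that validates the device and model format, then dispatches to the format-specific initializers internally. A minimal sketch of the call pattern this enables follows; the include path and the factory-style wrapper are assumptions for illustration, not part of the diff.

// Sketch only: the unified entry point added by this commit.
// The include path is an assumption; it has moved between FastDeploy versions.
#include <memory>
#include "fastdeploy/runtime/backends/tensorrt/trt_backend.h"

std::unique_ptr<fastdeploy::BaseBackend> MakeTrtBackend(
    const fastdeploy::RuntimeOption& option) {
  auto backend = std::make_unique<fastdeploy::TrtBackend>();
  // One call regardless of PADDLE vs. ONNX, on-disk vs. in-memory model.
  if (!backend->Init(option)) {
    return nullptr;
  }
  return backend;
}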


@@ -113,6 +113,50 @@ bool TrtBackend::LoadTrtCache(const std::string& trt_engine_file) {
  return true;
}

bool TrtBackend::Init(const RuntimeOption& runtime_option) {
  if (runtime_option.device != Device::GPU) {
    FDERROR << "TrtBackend only supports Device::GPU, but now it's "
            << runtime_option.device << "." << std::endl;
    return false;
  }
  if (runtime_option.model_format != ModelFormat::PADDLE &&
      runtime_option.model_format != ModelFormat::ONNX) {
    FDERROR
        << "TrtBackend only supports model format PADDLE/ONNX, but now it's "
        << runtime_option.model_format << "." << std::endl;
    return false;
  }
  if (runtime_option.model_format == ModelFormat::PADDLE) {
    if (runtime_option.model_from_memory_) {
      return InitFromPaddle(runtime_option.model_file,
                            runtime_option.params_file,
                            runtime_option.trt_option);
    } else {
      std::string model_buffer;
      std::string params_buffer;
      FDASSERT(ReadBinaryFromFile(runtime_option.model_file, &model_buffer),
               "Failed to read model file %s.",
               runtime_option.model_file.c_str());
      FDASSERT(ReadBinaryFromFile(runtime_option.params_file, &params_buffer),
               "Failed to read parameters file %s.",
               runtime_option.params_file.c_str());
      return InitFromPaddle(model_buffer, params_buffer,
                            runtime_option.trt_option);
    }
  } else {
    if (runtime_option.model_from_memory_) {
      return InitFromOnnx(runtime_option.model_file, runtime_option.trt_option);
    } else {
      std::string model_buffer;
      FDASSERT(ReadBinaryFromFile(runtime_option.model_file, &model_buffer),
               "Failed to read model file %s.",
               runtime_option.model_file.c_str());
      return InitFromOnnx(model_buffer, runtime_option.trt_option);
    }
  }
  return true;
}

bool TrtBackend::InitFromPaddle(const std::string& model_buffer,
                                const std::string& params_buffer,
                                const TrtBackendOption& option, bool verbose) {
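For context, a usage sketch of the two paths Init() distinguishes. UseGpu(), SetModelPath() and SetModelBuffer() are taken from FastDeploy's RuntimeOption API; exact signatures may vary by version, and the file names are placeholders.

// Path 1: model on disk. Init() reads the files itself via ReadBinaryFromFile.
fastdeploy::RuntimeOption opt;
opt.UseGpu(0);  // Init() rejects anything other than Device::GPU
opt.SetModelPath("model.pdmodel", "model.pdiparams",
                 fastdeploy::ModelFormat::PADDLE);
fastdeploy::TrtBackend backend;
bool ok = backend.Init(opt);

// Path 2: model already in memory. With model_from_memory_ set, the
// model_file/params_file members carry the raw buffers, which Init()
// forwards to InitFromPaddle()/InitFromOnnx() unchanged.
// opt.SetModelBuffer(model_buffer, params_buffer,
//                    fastdeploy::ModelFormat::PADDLE);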
@@ -291,14 +335,14 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
  cudaSetDevice(option_.gpu_id);
  SetInputs(inputs);
  AllocateOutputsBuffer(outputs, copy_to_fd);

  RUNTIME_PROFILE_LOOP_BEGIN(1)
  if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
    FDERROR << "Failed to Infer with TensorRT." << std::endl;
    return false;
  }
  RUNTIME_PROFILE_LOOP_END

  for (size_t i = 0; i < outputs->size(); ++i) {
    // If the final output tensor's dtype differs from the model output
    // tensor's dtype, we need to cast the data to the final output's dtype.


@@ -70,14 +70,8 @@ FDDataType GetFDDataType(const nvinfer1::DataType& dtype);

class TrtBackend : public BaseBackend {
 public:
  TrtBackend() : engine_(nullptr), context_(nullptr) {}

  void BuildOption(const TrtBackendOption& option);
  bool InitFromPaddle(const std::string& model_buffer,
                      const std::string& params_buffer,
                      const TrtBackendOption& option = TrtBackendOption(),
                      bool verbose = false);
  bool InitFromOnnx(const std::string& model_buffer,
                    const TrtBackendOption& option = TrtBackendOption());

  bool Init(const RuntimeOption& runtime_option);

  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
             bool copy_to_fd = true) override;
@@ -98,6 +92,15 @@ class TrtBackend : public BaseBackend {
  }

 private:
  void BuildOption(const TrtBackendOption& option);
  bool InitFromPaddle(const std::string& model_buffer,
                      const std::string& params_buffer,
                      const TrtBackendOption& option = TrtBackendOption(),
                      bool verbose = false);
  bool InitFromOnnx(const std::string& model_buffer,
                    const TrtBackendOption& option = TrtBackendOption());

  TrtBackendOption option_;
  std::shared_ptr<nvinfer1::ICudaEngine> engine_;
  std::shared_ptr<nvinfer1::IExecutionContext> context_;
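The header change is the other half of the unification: BuildOption(), InitFromPaddle() and InitFromOnnx() move from the public section to the private one, leaving Init(const RuntimeOption&) as the only way to initialize the backend from outside. A sketch of what that means at a call site (hypothetical snippet; whether Init() is also declared on BaseBackend is not shown in this diff):

fastdeploy::TrtBackend backend;
// backend.InitFromOnnx(buf);        // no longer compiles: private after this commit
// backend.BuildOption(trt_option);  // likewise private
bool ok = backend.Init(runtime_option);  // the single public entry point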