[Other] Unify initialize api for lite/trt backend (#1249)

* Unify initialize api for lite/trt backend

* Unify initialize api for lite/trt backend
This commit is contained in:
Jason
2023-02-08 11:16:39 +08:00
committed by GitHub
parent 9712f250a5
commit c5b414a774
5 changed files with 94 additions and 75 deletions

View File

@@ -70,14 +70,8 @@ FDDataType GetFDDataType(const nvinfer1::DataType& dtype);
class TrtBackend : public BaseBackend {
public:
TrtBackend() : engine_(nullptr), context_(nullptr) {}
void BuildOption(const TrtBackendOption& option);
bool InitFromPaddle(const std::string& model_buffer,
const std::string& params_buffer,
const TrtBackendOption& option = TrtBackendOption(),
bool verbose = false);
bool InitFromOnnx(const std::string& model_buffer,
const TrtBackendOption& option = TrtBackendOption());
bool Init(const RuntimeOption& runtime_option);
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
bool copy_to_fd = true) override;
@@ -98,6 +92,15 @@ class TrtBackend : public BaseBackend {
}
private:
void BuildOption(const TrtBackendOption& option);
bool InitFromPaddle(const std::string& model_buffer,
const std::string& params_buffer,
const TrtBackendOption& option = TrtBackendOption(),
bool verbose = false);
bool InitFromOnnx(const std::string& model_buffer,
const TrtBackendOption& option = TrtBackendOption());
TrtBackendOption option_;
std::shared_ptr<nvinfer1::ICudaEngine> engine_;
std::shared_ptr<nvinfer1::IExecutionContext> context_;