fix some usage problem in linux (#25)

Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
This commit is contained in:
Jason
2022-07-19 11:47:13 +08:00
committed by GitHub
parent e7c6a9d346
commit f670520bf8
5 changed files with 56 additions and 17 deletions

View File

@@ -40,12 +40,15 @@ void BindRuntime(pybind11::module& m) {
.def_readwrite("trt_max_batch_size", &RuntimeOption::trt_max_batch_size)
.def_readwrite("trt_max_workspace_size",
&RuntimeOption::trt_max_workspace_size);
pybind11::class_<TensorInfo>(m, "TensorInfo")
.def_readwrite("name", &TensorInfo::name)
.def_readwrite("shape", &TensorInfo::shape)
.def_readwrite("dtype", &TensorInfo::dtype);
pybind11::class_<Runtime>(m, "Runtime")
.def(pybind11::init([](RuntimeOption& option) {
Runtime* runtime = new Runtime();
runtime->Init(option);
return runtime;
}))
.def(pybind11::init())
.def("init", &Runtime::Init)
.def("infer", [](Runtime& self,
std::map<std::string, pybind11::array>& data) {
std::vector<FDTensor> inputs(data.size());
@@ -75,7 +78,12 @@ void BindRuntime(pybind11::module& m) {
outputs[i].Numel() * FDDataTypeSize(outputs[i].dtype));
}
return results;
});
})
.def("num_inputs", &Runtime::NumInputs)
.def("num_outputs", &Runtime::NumOutputs)
.def("get_input_info", &Runtime::GetInputInfo)
.def("get_output_info", &Runtime::GetOutputInfo)
.def_readonly("option", &Runtime::option);
pybind11::enum_<Backend>(m, "Backend", pybind11::arithmetic(),
"Backend for inference.")
@@ -103,11 +111,6 @@ void BindRuntime(pybind11::module& m) {
.value("FP64", FDDataType::FP64)
.value("UINT8", FDDataType::UINT8);
pybind11::class_<TensorInfo>(m, "TensorInfo")
.def_readwrite("name", &TensorInfo::name)
.def_readwrite("shape", &TensorInfo::shape)
.def_readwrite("dtype", &TensorInfo::dtype);
m.def("get_available_backends", []() { return GetAvailableBackends(); });
}