[Runtime]FDTensor pybind add from_dlpack interface (#1001)

add from_dlpack
This commit is contained in:
heliqi
2022-12-28 21:58:32 +08:00
committed by GitHub
parent 02e2a5365b
commit ab0929662b
2 changed files with 71 additions and 1 deletion

@@ -166,6 +166,74 @@ pybind11::capsule FDTensorToDLPack(FDTensor& fd_tensor) {
      static_cast<void*>(dlpack_tensor), "dltensor", &DeleteUnusedDltensor);
}

FDTensor FDTensorFromDLPack(const std::string& name,
                            const pybind11::capsule& dlpack_tensor) {
  DLManagedTensor* dl_managed_tensor =
      static_cast<DLManagedTensor*>(dlpack_tensor.get_pointer());

  // Apply the byte offset to locate the start of the tensor data.
  void* memory_ptr = dl_managed_tensor->dl_tensor.data;
  memory_ptr = reinterpret_cast<char*>(memory_ptr) +
               dl_managed_tensor->dl_tensor.byte_offset;

  int64_t* strides = dl_managed_tensor->dl_tensor.strides;
  int ndim = dl_managed_tensor->dl_tensor.ndim;
  std::vector<int64_t> dims(dl_managed_tensor->dl_tensor.shape,
                            dl_managed_tensor->dl_tensor.shape + ndim);

  // Check if the input is contiguous and in C order.
  if (strides != nullptr) {
    int64_t calculated_stride{1};
    bool is_contiguous_c_order = true;
    for (size_t i = 1; i < dims.size(); i++) {
      if (strides[ndim - i] != calculated_stride) {
        is_contiguous_c_order = false;
        break;
      }
      calculated_stride *= dims[ndim - i];
    }
    FDASSERT(is_contiguous_c_order,
             "DLPack tensor is not contiguous. Only contiguous DLPack "
             "tensors that are stored in C-Order are supported.");
  }

  // Map the DLPack device onto an FDTensor device.
  Device device;
  int32_t device_id = -1;
  bool is_pinned_memory = false;
  switch (dl_managed_tensor->dl_tensor.device.device_type) {
    case DLDeviceType::kDLCUDA:
      device = Device::GPU;
      device_id = dl_managed_tensor->dl_tensor.device.device_id;
      break;
    case DLDeviceType::kDLCPU:
      device = Device::CPU;
      break;
    case DLDeviceType::kDLCUDAHost:
      device = Device::CPU;
      is_pinned_memory = true;
      break;
    default:
      FDASSERT(false,
               ("DLDevice type " +
                std::to_string(dl_managed_tensor->dl_tensor.device.device_type) +
                " is not supported by the Python backend.")
                   .c_str());
      break;
  }

  FDDataType dtype = DlpackToFDType(dl_managed_tensor->dl_tensor.dtype);

  // Rename the capsule so it is marked as consumed and cannot be imported twice.
  PyCapsule_SetName(dlpack_tensor.ptr(), "used_dlpack");

  // The FDTensor shares the DLPack memory rather than copying it.
  FDTensor fd_tensor(name);
  fd_tensor.SetExternalData(dims, dtype, memory_ptr, device, device_id);
  fd_tensor.is_pinned_memory = is_pinned_memory;
  return fd_tensor;
}
void BindFDTensor(pybind11::module& m) {
  pybind11::class_<FDTensor>(m, "FDTensor")
@@ -182,6 +250,7 @@ void BindFDTensor(pybind11::module& m) {
             PyArrayToTensor(pyarray, &self, share_buffer);
           })
      .def("to_dlpack", &FDTensorToDLPack)
      .def("from_dlpack", &FDTensorFromDLPack)
      .def("print_info", &FDTensor::PrintInfo);
}
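
For reference, the new interface is consumed from Python. Below is a minimal usage sketch of a DLPack round trip; the `fastdeploy.C` import path and the tensor name "x" are assumptions for illustration, and Paddle is used only as a convenient producer/consumer of DLPack capsules. Only `to_dlpack`, `from_dlpack`, and `print_info` come from the bindings in this diff.

# Hedged usage sketch, not taken from this commit: the import path below is an
# assumption about how the pybind module is exposed in the Python package.
import numpy as np
import paddle
from paddle.utils.dlpack import to_dlpack, from_dlpack

from fastdeploy import C  # assumed location of the compiled pybind module

# Export a Paddle tensor as a DLPack capsule.
paddle_tensor = paddle.to_tensor(np.arange(6, dtype="float32").reshape(2, 3))
capsule = to_dlpack(paddle_tensor)

# Wrap the capsule in an FDTensor. `from_dlpack` is bound from a free function
# taking (name, capsule), so it is called through the class here.
fd_tensor = C.FDTensor.from_dlpack("x", capsule)
fd_tensor.print_info()

# The reverse direction uses the existing `to_dlpack` binding.
capsule_back = fd_tensor.to_dlpack()
restored = from_dlpack(capsule_back)

Because `SetExternalData` shares the capsule's memory rather than copying it, the producing tensor must outlive the FDTensor; the capsule is also renamed to "used_dlpack", so it cannot be imported a second time.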