mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-11 11:30:20 +08:00
[Runtime]FDTensor pybind add from_dlpack interface (#1001)
add from_dlpack
This commit is contained in:
@@ -166,6 +166,74 @@ pybind11::capsule FDTensorToDLPack(FDTensor& fd_tensor) {
|
||||
static_cast<void*>(dlpack_tensor), "dltensor", &DeleteUnusedDltensor);
|
||||
}
|
||||
|
||||
FDTensor FDTensorFromDLPack(const std::string& name,
|
||||
const pybind11::capsule& dlpack_tensor) {
|
||||
DLManagedTensor* dl_managed_tensor =
|
||||
static_cast<DLManagedTensor*>(dlpack_tensor.get_pointer());
|
||||
|
||||
void* memory_ptr = dl_managed_tensor->dl_tensor.data;
|
||||
memory_ptr = reinterpret_cast<char*>(memory_ptr) +
|
||||
dl_managed_tensor->dl_tensor.byte_offset;
|
||||
|
||||
int64_t* strides = dl_managed_tensor->dl_tensor.strides;
|
||||
|
||||
int ndim = dl_managed_tensor->dl_tensor.ndim;
|
||||
std::vector<int64_t> dims(
|
||||
dl_managed_tensor->dl_tensor.shape,
|
||||
dl_managed_tensor->dl_tensor.shape + ndim);
|
||||
|
||||
// Check if the input is contiguous and in C order
|
||||
if (strides != nullptr) {
|
||||
int64_t calculated_stride{1};
|
||||
bool is_contiguous_c_order = true;
|
||||
for (size_t i = 1; i < dims.size(); i++) {
|
||||
if (strides[ndim - i] != calculated_stride) {
|
||||
is_contiguous_c_order = false;
|
||||
break;
|
||||
}
|
||||
|
||||
calculated_stride *= dims[ndim - i];
|
||||
}
|
||||
|
||||
FDASSERT(is_contiguous_c_order,
|
||||
"DLPack tensor is not contiguous. Only contiguous DLPack "
|
||||
"tensors that are stored in C-Order are supported.");
|
||||
}
|
||||
|
||||
Device device;
|
||||
int32_t device_id = -1;
|
||||
bool is_pinned_memory = false;
|
||||
switch (dl_managed_tensor->dl_tensor.device.device_type) {
|
||||
case DLDeviceType::kDLCUDA:
|
||||
device = Device::GPU;
|
||||
device_id = dl_managed_tensor->dl_tensor.device.device_id;
|
||||
break;
|
||||
case DLDeviceType::kDLCPU:
|
||||
device = Device::CPU;
|
||||
break;
|
||||
case DLDeviceType::kDLCUDAHost:
|
||||
device = Device::CPU;
|
||||
is_pinned_memory = true;
|
||||
break;
|
||||
default:
|
||||
FDASSERT(false,
|
||||
("DLDevice type " +
|
||||
std::to_string(dl_managed_tensor->dl_tensor.device.device_type) +
|
||||
" is not support by Python backend.").c_str());
|
||||
break;
|
||||
}
|
||||
|
||||
FDDataType dtype =
|
||||
DlpackToFDType(dl_managed_tensor->dl_tensor.dtype);
|
||||
|
||||
PyCapsule_SetName(dlpack_tensor.ptr(), "used_dlpack");
|
||||
FDTensor fd_tensor(name);
|
||||
fd_tensor.SetExternalData(
|
||||
dims, dtype, memory_ptr, device, device_id
|
||||
);
|
||||
fd_tensor.is_pinned_memory = is_pinned_memory;
|
||||
return fd_tensor;
|
||||
}
|
||||
|
||||
void BindFDTensor(pybind11::module& m) {
|
||||
pybind11::class_<FDTensor>(m, "FDTensor")
|
||||
@@ -182,6 +250,7 @@ void BindFDTensor(pybind11::module& m) {
|
||||
PyArrayToTensor(pyarray, &self, share_buffer);
|
||||
})
|
||||
.def("to_dlpack", &FDTensorToDLPack)
|
||||
.def("from_dlpack",&FDTensorFromDLPack)
|
||||
.def("print_info", &FDTensor::PrintInfo);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user