mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 09:31:35 +08:00
[Paddle Lite] Support stable-diffusion model (#830)
* support stable-diffusion model for paddlelite * update code
This commit is contained in:
@@ -206,24 +206,24 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
// Adjust dims only, allocate lazy.
|
||||
tensor->Resize(inputs[i].shape);
|
||||
if (inputs[i].dtype == FDDataType::FP32) {
|
||||
tensor->CopyFromCpu<float, paddle::lite_api::TargetType::kARM>(
|
||||
tensor->CopyFromCpu<float, paddle::lite_api::TargetType::kHost>(
|
||||
reinterpret_cast<const float*>(const_cast<void*>(
|
||||
inputs[i].CpuData())));
|
||||
} else if (inputs[i].dtype == FDDataType::INT32) {
|
||||
tensor->CopyFromCpu<int, paddle::lite_api::TargetType::kARM>(
|
||||
tensor->CopyFromCpu<int, paddle::lite_api::TargetType::kHost>(
|
||||
reinterpret_cast<const int*>(const_cast<void*>(
|
||||
inputs[i].CpuData())));
|
||||
} else if (inputs[i].dtype == FDDataType::INT8) {
|
||||
tensor->CopyFromCpu<int8_t, paddle::lite_api::TargetType::kARM>(
|
||||
tensor->CopyFromCpu<int8_t, paddle::lite_api::TargetType::kHost>(
|
||||
reinterpret_cast<const int8_t*>(const_cast<void*>(
|
||||
inputs[i].CpuData())));
|
||||
} else if (inputs[i].dtype == FDDataType::UINT8) {
|
||||
tensor->CopyFromCpu<uint8_t, paddle::lite_api::TargetType::kARM>(
|
||||
tensor->CopyFromCpu<uint8_t, paddle::lite_api::TargetType::kHost>(
|
||||
reinterpret_cast<const uint8_t*>(const_cast<void*>(
|
||||
inputs[i].CpuData())));
|
||||
} else if (inputs[i].dtype == FDDataType::INT64) {
|
||||
#ifdef __aarch64__
|
||||
tensor->CopyFromCpu<int64_t, paddle::lite_api::TargetType::kARM>(
|
||||
tensor->CopyFromCpu<int64_t, paddle::lite_api::TargetType::kHost>(
|
||||
reinterpret_cast<const int64_t*>(const_cast<void*>(
|
||||
inputs[i].CpuData())));
|
||||
#else
|
||||
|
@@ -35,6 +35,8 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) {
|
||||
dt = pybind11::dtype::of<double>();
|
||||
} else if (fd_dtype == FDDataType::UINT8) {
|
||||
dt = pybind11::dtype::of<uint8_t>();
|
||||
} else if (fd_dtype == FDDataType::INT8) {
|
||||
dt = pybind11::dtype::of<int8_t>();
|
||||
} else if (fd_dtype == FDDataType::FP16) {
|
||||
dt = pybind11::dtype::of<float16>();
|
||||
} else {
|
||||
@@ -55,12 +57,14 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
|
||||
return FDDataType::FP64;
|
||||
} else if (np_dtype.is(pybind11::dtype::of<uint8_t>())) {
|
||||
return FDDataType::UINT8;
|
||||
} else if (np_dtype.is(pybind11::dtype::of<int8_t>())) {
|
||||
return FDDataType::INT8;
|
||||
} else if (np_dtype.is(pybind11::dtype::of<float16>())) {
|
||||
return FDDataType::FP16;
|
||||
}
|
||||
FDASSERT(false,
|
||||
"NumpyDataTypeToFDDataType() only support "
|
||||
"int32/int64/float32/float64/float16 now.");
|
||||
"int8/int32/int64/float32/float64/float16 now.");
|
||||
return FDDataType::FP32;
|
||||
}
|
||||
|
||||
|
@@ -67,9 +67,11 @@ FDDataType CTypeToFDDataType() {
|
||||
return FDDataType::FP32;
|
||||
} else if (std::is_same<T, double>::value) {
|
||||
return FDDataType::FP64;
|
||||
} else if (std::is_same<T, int8_t>::value) {
|
||||
return FDDataType::INT8;
|
||||
}
|
||||
FDASSERT(false,
|
||||
"CTypeToFDDataType only support int32/int64/float32/float64 now.");
|
||||
FDASSERT(false, "CTypeToFDDataType only support "
|
||||
"int8/int32/int64/float32/float64 now.");
|
||||
return FDDataType::FP32;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user