[Paddle Lite] Support stable-diffusion model (#830)

* support stable-diffusion model for paddlelite

* update code
Author: shentanyue
Date: 2022-12-09 13:20:33 +08:00
Committed by: GitHub
Parent: b0988bf423
Commit: 3c05c74513
3 changed files with 14 additions and 8 deletions

@@ -206,24 +206,24 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
     // Adjust dims only, allocate lazy.
     tensor->Resize(inputs[i].shape);
     if (inputs[i].dtype == FDDataType::FP32) {
-      tensor->CopyFromCpu<float, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<float, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const float*>(const_cast<void*>(
               inputs[i].CpuData())));
     } else if (inputs[i].dtype == FDDataType::INT32) {
-      tensor->CopyFromCpu<int, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<int, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const int*>(const_cast<void*>(
               inputs[i].CpuData())));
     } else if (inputs[i].dtype == FDDataType::INT8) {
-      tensor->CopyFromCpu<int8_t, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<int8_t, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const int8_t*>(const_cast<void*>(
               inputs[i].CpuData())));
     } else if (inputs[i].dtype == FDDataType::UINT8) {
-      tensor->CopyFromCpu<uint8_t, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<uint8_t, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const uint8_t*>(const_cast<void*>(
               inputs[i].CpuData())));
     } else if (inputs[i].dtype == FDDataType::INT64) {
 #ifdef __aarch64__
-      tensor->CopyFromCpu<int64_t, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<int64_t, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const int64_t*>(const_cast<void*>(
               inputs[i].CpuData())));
 #else
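The hunk above changes every `CopyFromCpu` call in `LiteBackend::Infer` from `TargetType::kARM` to `TargetType::kHost`, untying the input copy from ARM-only builds. Below is a minimal sketch, assuming Paddle Lite's public `paddle_api.h`, of the copy pattern these branches repeat; `CopyHostInput` is a hypothetical helper, not part of the commit:

```cpp
// Sketch only: one templated copy per dtype, capturing the pattern repeated
// in the hunk above. Tensor::CopyFromCpu<T, TargetType> is Paddle Lite's API;
// CopyHostInput is illustrative.
#include <cstdint>
#include "paddle_api.h"  // paddle::lite_api::Tensor, paddle::lite_api::TargetType

template <typename T>
void CopyHostInput(paddle::lite_api::Tensor* tensor, const void* data) {
  // kHost addresses host (CPU) memory, so the same copy path presumably
  // works on non-ARM hosts as well, unlike the kARM target it replaces.
  tensor->CopyFromCpu<T, paddle::lite_api::TargetType::kHost>(
      reinterpret_cast<const T*>(data));
}

// Usage mirroring the FP32 branch above:
//   CopyHostInput<float>(tensor.get(), inputs[i].CpuData());
```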

@@ -35,6 +35,8 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) {
     dt = pybind11::dtype::of<double>();
   } else if (fd_dtype == FDDataType::UINT8) {
     dt = pybind11::dtype::of<uint8_t>();
+  } else if (fd_dtype == FDDataType::INT8) {
+    dt = pybind11::dtype::of<int8_t>();
   } else if (fd_dtype == FDDataType::FP16) {
     dt = pybind11::dtype::of<float16>();
   } else {
@@ -55,12 +57,14 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
     return FDDataType::FP64;
   } else if (np_dtype.is(pybind11::dtype::of<uint8_t>())) {
     return FDDataType::UINT8;
+  } else if (np_dtype.is(pybind11::dtype::of<int8_t>())) {
+    return FDDataType::INT8;
   } else if (np_dtype.is(pybind11::dtype::of<float16>())) {
     return FDDataType::FP16;
   }
   FDASSERT(false,
            "NumpyDataTypeToFDDataType() only support "
-           "int32/int64/float32/float64/float16 now.");
+           "int8/int32/int64/float32/float64/float16 now.");
   return FDDataType::FP32;
 }
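These two hunks register int8 in both directions of the numpy dtype mapping and extend the assertion message to match. A hedged sketch of the round trip this enables follows; the declarations are copied from the diff, while the check itself is illustrative rather than a standalone program (FDDataType and the converters live in FastDeploy's headers):

```cpp
// Illustrative round-trip check for the INT8 branches added above.
// The two declarations mirror the signatures shown in the diff.
#include <cassert>
#include <cstdint>
#include <pybind11/pybind11.h>

pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype);
FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype);

void CheckInt8RoundTrip() {
  // FDDataType::INT8 now maps to numpy's int8 dtype ...
  pybind11::dtype dt = FDDataTypeToNumpyDataType(FDDataType::INT8);
  assert(dt.is(pybind11::dtype::of<int8_t>()));
  // ... and back again, instead of tripping the FDASSERT as before.
  assert(NumpyDataTypeToFDDataType(dt) == FDDataType::INT8);
}
```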

@@ -67,9 +67,11 @@ FDDataType CTypeToFDDataType() {
     return FDDataType::FP32;
   } else if (std::is_same<T, double>::value) {
     return FDDataType::FP64;
+  } else if (std::is_same<T, int8_t>::value) {
+    return FDDataType::INT8;
   }
-  FDASSERT(false,
-           "CTypeToFDDataType only support int32/int64/float32/float64 now.");
+  FDASSERT(false, "CTypeToFDDataType only support "
+                  "int8/int32/int64/float32/float64 now.");
   return FDDataType::FP32;
 }
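With the new `std::is_same<T, int8_t>` branch, `CTypeToFDDataType<int8_t>()` resolves to `FDDataType::INT8` instead of falling through to the assertion. A small usage sketch, assuming FastDeploy's headers are in scope (`DtypeExamples` is illustrative, not part of the commit):

```cpp
// Illustrative usage of the function in the hunk above. The dispatch is a
// runtime if/else over std::is_same, so unsupported types still compile and
// only trip the FDASSERT when actually called.
#include <cstdint>

void DtypeExamples() {
  FDDataType a = CTypeToFDDataType<int8_t>();   // FDDataType::INT8 (new)
  FDDataType b = CTypeToFDDataType<int32_t>();  // FDDataType::INT32 (unchanged)
  (void)a;
  (void)b;
}
```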