diff --git a/examples/multimodal/stable_diffusion/cpp/CMakeLists.txt b/examples/multimodal/stable_diffusion/cpp/CMakeLists.txt
new file mode 100644
index 000000000..33b4afb93
--- /dev/null
+++ b/examples/multimodal/stable_diffusion/cpp/CMakeLists.txt
@@ -0,0 +1,30 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# cmake_minimum_required() has to precede project() so that the policy
+# settings are in effect when the project and its languages are configured.
+CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
+PROJECT(main C CXX)
+
+# Root of the FastDeploy SDK, passed as -DFASTDEPLOY_INSTALL_DIR=<path>.
+option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
+set(THIRD_LIBS "")
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+include_directories(${FASTDEPLOY_INCS})
+
+file(GLOB_RECURSE ALL_SRCS ${PROJECT_SOURCE_DIR}/*.cc)
+
+add_executable(main ${ALL_SRCS})
+target_link_libraries(main ${FASTDEPLOY_LIBS} ${THIRD_LIBS})
diff --git a/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.cc b/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.cc
new file mode 100644
index 000000000..cb6cf970b
--- /dev/null
+++ b/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.cc
@@ -0,0 +1,426 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dpm_solver_multistep_scheduler.h"
+#include "fastdeploy/core/fd_scalar.h"
+#include "fastdeploy/function/functions.h"
+#include <algorithm>
+#include <cmath>
+
+namespace fastdeploy {
+
+// Builds the "squaredcos_cap_v2" beta schedule:
+//   beta_i = min(1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), max_beta),
+// where t_i = i / num_diffusion_timesteps and alpha_bar is a squared cosine.
+// The per-step scalars are concatenated into `out`.
+void DPMSolverMultistepScheduler::BetaForAlphaBar(FDTensor* out,
+                                                  int num_diffusion_timesteps,
+                                                  float max_beta) {
+  auto alpha_bar = [](float time_step) -> float {
+    constexpr float pi = 3.14159265358979323846;
+    return std::pow(std::cos((time_step + 0.008) / 1.008 * pi / 2), 2);
+  };
+  std::vector<FDTensor> betas;
+  for (int i = 0; i < num_diffusion_timesteps; ++i) {
+    // Divide in floating point: integer division would truncate t1/t2 to 0.
+    float t1 = static_cast<float>(i) / num_diffusion_timesteps;
+    float t2 = static_cast<float>(i + 1) / num_diffusion_timesteps;
+    float beta_val = (std::min)(1 - alpha_bar(t2) / alpha_bar(t1), max_beta);
+    betas.emplace_back(Scalar(beta_val));
+  }
+  function::Concat(betas, out);
+}
+
+// Precomputes the beta/alpha tables and the derived alpha_t, sigma_t and
+// lambda_t = log(alpha_t) - log(sigma_t) tensors, validates the solver
+// configuration and initialises the full reversed training timestep list.
+DPMSolverMultistepScheduler::DPMSolverMultistepScheduler(
+    int num_train_timesteps, float beta_start, float beta_end,
+    const std::string& beta_schedule, const std::vector<float>& trained_betas,
+    int solver_order, bool predict_epsilon, bool thresholding,
+    float dynamic_thresholding_ratio, float sample_max_value,
+    const std::string& algorithm_type, const std::string& solver_type,
+    bool lower_order_final)
+    : config({num_train_timesteps, beta_start, beta_end, beta_schedule,
+              solver_order, predict_epsilon, thresholding,
+              dynamic_thresholding_ratio, sample_max_value, algorithm_type,
+              solver_type, lower_order_final}) {
+  int beta_size = trained_betas.size();
+  if (beta_size > 0) {
+    betas_.Allocate({beta_size}, FDDataType::FP32);
+    std::copy(trained_betas.data(), trained_betas.data() + beta_size,
+              reinterpret_cast<float*>(betas_.Data()));
+  } else if (beta_schedule == "linear") {
+    function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
+                       FDDataType::FP32);
+  } else if (beta_schedule == "scaled_linear") {
+    function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
+                       FDDataType::FP32);
+    betas_ = betas_ * betas_;
+  } else if (beta_schedule == "squaredcos_cap_v2") {
+    BetaForAlphaBar(&betas_, num_train_timesteps);
+  } else {
+    FDASSERT(false, "%s is not implemented for DPMSolverMultistepScheduler",
+             beta_schedule.c_str());
+  }
+
+  alphas_ = 1.0f - betas_;
+  function::Cumprod(alphas_, &alphas_cumprod_);
+  function::Sqrt(alphas_cumprod_, &alpha_t_);
+  function::Sqrt(1.0f - alphas_cumprod_, &sigma_t_);
+  FDTensor alpha_t_log, sigma_t_log;
+  function::Log(alpha_t_, &alpha_t_log);
+  function::Log(sigma_t_, &sigma_t_log);
+  lambda_t_ = alpha_t_log - sigma_t_log;
+
+  FDASSERT(config.algorithm_type_ == "dpmsolver" ||
+               config.algorithm_type_ == "dpmsolver++",
+           "%s is not implemented for DPMSolverMultistepScheduler",
+           config.algorithm_type_.c_str());
+  FDASSERT(config.solver_type_ == "midpoint" || config.solver_type_ == "heun",
+           "%s is not implemented for DPMSolverMultistepScheduler",
+           config.solver_type_.c_str());
+  FDASSERT(config.solver_order_ >= 1,
+           "solver_order should be at least 1, but received %d",
+           config.solver_order_);
+  num_inference_steps_ = -1;
+
+  function::Linspace(0, config.num_train_timesteps_ - 1,
+                     config.num_train_timesteps_, &timesteps_);
+  function::Cast(timesteps_, &timesteps_, FDDataType::INT64);
+  // Reverse timesteps
+  int64_t* timesteps_data = reinterpret_cast<int64_t*>(timesteps_.Data());
+  std::reverse(timesteps_data, timesteps_data + timesteps_.Numel());
+
+  model_outputs_.resize(config.solver_order_);
+  lower_order_nums_ = 0;
+}
+
+// Converts the raw network output into the quantity the solver integrates:
+// the predicted x0 for "dpmsolver++" (optionally dynamically thresholded) or
+// the predicted noise for "dpmsolver".
+void DPMSolverMultistepScheduler::ConvertModelOutput(
+    const FDTensor& model_output, int timestep, const FDTensor& sample,
+    FDTensor* out) {
+  if (config.algorithm_type_ == "dpmsolver++") {
+    FDTensor x0_pred;
+    if (config.predict_epsilon_) {
+      FDTensor alpha_t, sigma_t;
+      function::Slice(alpha_t_, {0}, {timestep}, &alpha_t);
+      function::Slice(sigma_t_, {0}, {timestep}, &sigma_t);
+      x0_pred = (sample - sigma_t * model_output) / alpha_t;
+    } else {
+      x0_pred = model_output;
+    }
+    if (config.thresholding_) {
+      FDTensor dynamic_max_val, x0_pred_abs;
+      function::Abs(x0_pred, &x0_pred_abs);
+      x0_pred_abs.Reshape({x0_pred_abs.Shape()[0], -1});
+      function::Quantile(x0_pred_abs, {config.dynamic_thresholding_ratio_}, {1},
+                         &dynamic_max_val);
+
+      FDTensor max_value, dy_max_val;
+      function::FullLike(dynamic_max_val, config.sample_max_value_, &max_value,
+                         dynamic_max_val.Dtype());
+      function::Maximum(dynamic_max_val, max_value, &dy_max_val);
+      int expand_dims = x0_pred.Shape().size() - 1;
+      for (int i = 0; i < expand_dims; ++i) {
+        dy_max_val.ExpandDim(dy_max_val.Shape().size());
+      }
+      float clip_max = reinterpret_cast<float*>(dy_max_val.Data())[0];
+      function::Clip(x0_pred, -clip_max, clip_max, &x0_pred);
+      x0_pred = x0_pred / dy_max_val;
+    }
+    *out = std::move(x0_pred);
+  } else if (config.algorithm_type_ == "dpmsolver") {
+    if (config.predict_epsilon_) {
+      *out = model_output;
+    } else {
+      FDTensor alpha_t, sigma_t;
+      function::Slice(alpha_t_, {0}, {timestep}, &alpha_t);
+      function::Slice(sigma_t_, {0}, {timestep}, &sigma_t);
+      *out = (sample - (alpha_t * model_output)) / sigma_t;
+    }
+  }
+}
+
+// One first-order DPM-Solver step from `timestep` to `prev_timestep`.
+void DPMSolverMultistepScheduler::DPMSolverFirstOrderUpdate(
+    const FDTensor& model_output, int timestep, int prev_timestep,
+    const FDTensor& sample, FDTensor* out) {
+  FDTensor lambda_t, lambda_s;
+  function::Slice(lambda_t_, {0}, {prev_timestep}, &lambda_t);
+  function::Slice(lambda_t_, {0}, {timestep}, &lambda_s);
+
+  FDTensor alpha_t, alpha_s;
+  function::Slice(alpha_t_, {0}, {prev_timestep}, &alpha_t);
+  function::Slice(alpha_t_, {0}, {timestep}, &alpha_s);
+
+  FDTensor sigma_t, sigma_s;
+  function::Slice(sigma_t_, {0}, {prev_timestep}, &sigma_t);
+  function::Slice(sigma_t_, {0}, {timestep}, &sigma_s);
+
+  FDTensor h = lambda_t - lambda_s;
+  if (config.algorithm_type_ == "dpmsolver++") {
+    function::Exp(0.0f - h, &h);
+    *out = (sigma_t / sigma_s) * sample - (alpha_t * (h - 1.0f)) * model_output;
+  } else if (config.algorithm_type_ == "dpmsolver") {
+    function::Exp(h, &h);
+    *out = (alpha_t / alpha_s) * sample - (sigma_t * (h - 1.0f)) * model_output;
+  }
+}
+
+// One second-order multistep update using the two most recent model outputs;
+// `timestep_list` holds the matching timesteps, most recent last.
+void DPMSolverMultistepScheduler::MultiStepDPMSolverSecondOrderUpdate(
+    const std::vector<FDTensor>& model_output_list,
+    const std::vector<int>& timestep_list, int prev_timestep,
+    const FDTensor& sample, FDTensor* out) {
+  int timestep_size = timestep_list.size();
+  int model_output_size = model_output_list.size();
+  int t = prev_timestep;
+  int s0 = timestep_list[timestep_size - 1];
+  int s1 = timestep_list[timestep_size - 2];
+  const FDTensor& m0 = model_output_list[model_output_size - 1];
+  const FDTensor& m1 = model_output_list[model_output_size - 2];
+  FDTensor lambda_t, lambda_s0, lambda_s1;
+  function::Slice(lambda_t_, {0}, {t}, &lambda_t);
+  function::Slice(lambda_t_, {0}, {s0}, &lambda_s0);
+  function::Slice(lambda_t_, {0}, {s1}, &lambda_s1);
+
+  FDTensor alpha_t, alpha_s0, sigma_t, sigma_s0;
+  function::Slice(alpha_t_, {0}, {t}, &alpha_t);
+  function::Slice(alpha_t_, {0}, {s0}, &alpha_s0);
+  function::Slice(sigma_t_, {0}, {t}, &sigma_t);
+  function::Slice(sigma_t_, {0}, {s0}, &sigma_s0);
+
+  FDTensor h = lambda_t - lambda_s0;
+  FDTensor h0 = lambda_s0 - lambda_s1;
+  FDTensor r0 = h0 / h;
+  FDTensor D0 = m0;
+  FDTensor D1 = (1.0f / r0) * (m0 - m1);
+  if (config.algorithm_type_ == "dpmsolver++") {
+    if (config.solver_type_ == "midpoint") {
+      function::Exp(0.0f - h, &h);
+      *out = (sigma_t / sigma_s0 * sample) - (alpha_t * (h - 1.0f) * D0) -
+             (0.5f * alpha_t * (h - 1.0f) * D1);
+    } else if (config.solver_type_ == "heun") {
+      FDTensor h_exp;
+      function::Exp(0.0f - h, &h_exp);
+      *out = (sigma_t / sigma_s0 * sample) - (alpha_t * (h_exp - 1.0f) * D0) +
+             (alpha_t * ((h_exp - 1.0f) / h + 1.0f) * D1);
+    }
+  } else if (config.algorithm_type_ == "dpmsolver") {
+    FDTensor h_exp;
+    function::Exp(h, &h_exp);
+    if (config.solver_type_ == "midpoint") {
+      *out = alpha_t / alpha_s0 * sample - sigma_t * (h_exp - 1.0f) * D0 -
+             0.5 * (sigma_t * (h_exp - 1.0f) * D1);
+    } else if (config.solver_type_ == "heun") {
+      *out = alpha_t / alpha_s0 * sample - sigma_t * (h_exp - 1.0f) * D0 -
+             (sigma_t * ((h_exp - 1.0f) / h - 1.0f) * D1);
+    }
+  }
+}
+
+// One third-order multistep update using the three most recent model
+// outputs; `timestep_list` holds the matching timesteps, most recent last.
+void DPMSolverMultistepScheduler::MultiStepDPMSolverThirdOrderUpdate(
+    const std::vector<FDTensor>& model_output_list,
+    const std::vector<int>& timestep_list, int prev_timestep,
+    const FDTensor& sample, FDTensor* out) {
+  int timestep_size = timestep_list.size();
+  int model_output_size = model_output_list.size();
+  int t = prev_timestep;
+
+  int s0 = timestep_list[timestep_size - 1];
+  int s1 = timestep_list[timestep_size - 2];
+  int s2 = timestep_list[timestep_size - 3];
+  const FDTensor& m0 = model_output_list[model_output_size - 1];
+  const FDTensor& m1 = model_output_list[model_output_size - 2];
+  const FDTensor& m2 = model_output_list[model_output_size - 3];
+
+  FDTensor lambda_t, lambda_s0, lambda_s1, lambda_s2;
+  function::Slice(lambda_t_, {0}, {t}, &lambda_t);
+  function::Slice(lambda_t_, {0}, {s0}, &lambda_s0);
+  function::Slice(lambda_t_, {0}, {s1}, &lambda_s1);
+  function::Slice(lambda_t_, {0}, {s2}, &lambda_s2);
+
+  FDTensor alpha_t, alpha_s0, sigma_t, sigma_s0;
+  function::Slice(alpha_t_, {0}, {t}, &alpha_t);
+  function::Slice(alpha_t_, {0}, {s0}, &alpha_s0);
+  function::Slice(sigma_t_, {0}, {t}, &sigma_t);
+  function::Slice(sigma_t_, {0}, {s0}, &sigma_s0);
+
+  FDTensor h = lambda_t - lambda_s0;
+  FDTensor h0 = lambda_s0 - lambda_s1;
+  FDTensor h1 = lambda_s1 - lambda_s2;
+
+  FDTensor r0 = h0 / h;
+  FDTensor r1 = h1 / h;
+  FDTensor D0 = m0;
+  FDTensor D1_0 = (1.0f / r0) * (m0 - m1);
+  FDTensor D1_1 = (1.0f / r1) * (m1 - m2);
+  FDTensor D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1);
+  FDTensor D2 = (1.0f / (r0 + r1)) * (D1_0 - D1_1);
+
+  if (config.algorithm_type_ == "dpmsolver++") {
+    FDTensor h_exp;
+    function::Exp(0.0f - h, &h_exp);
+    *out = (sigma_t / sigma_s0) * sample - (alpha_t * (h_exp - 1.0f)) * D0 +
+           (alpha_t * ((h_exp - 1.0) / h + 1.0)) * D1 -
+           (alpha_t * ((h_exp - 1.0 + h) / (h * h) - 0.5)) * D2;
+
+  } else if (config.algorithm_type_ == "dpmsolver") {
+    FDTensor h_exp;
+    function::Exp(h, &h_exp);
+    // The D1 term is subtracted, matching the second-order "heun" update.
+    *out = (alpha_t / alpha_s0) * sample - (sigma_t * (h_exp - 1.0f)) * D0 -
+           (sigma_t * ((h_exp - 1.0) / h - 1.0)) * D1 -
+           (sigma_t * ((h_exp - 1.0 - h) / (h * h) - 0.5)) * D2;
+  }
+}
+
+// DPM-Solver needs no input scaling: copies `sample` to `out` unchanged.
+void DPMSolverMultistepScheduler::ScaleModelInput(
+    const FDTensor& sample, FDTensor* out,
+    const std::vector<FDTensor>& timesteps) {
+  *out = sample;
+}
+
+// Selects `num_inference_steps` timesteps, evenly spaced over the training
+// range and in descending order, and resets the multistep history.
+void DPMSolverMultistepScheduler::SetTimesteps(int num_inference_steps) {
+  num_inference_steps_ = num_inference_steps;
+  function::Linspace(0, config.num_train_timesteps_ - 1,
+                     num_inference_steps + 1, &timesteps_);
+  function::Round(timesteps_, &timesteps_);
+  // Reverse timesteps
+  // NOTE(review): assumes Linspace produced FP32 data here - confirm.
+  float* timesteps_data = reinterpret_cast<float*>(timesteps_.Data());
+  std::reverse(timesteps_data, timesteps_data + timesteps_.Numel());
+  // Keep the first num_inference_steps entries, dropping the smallest one.
+  FDTensor timestep_tmp;
+  timestep_tmp.Allocate({num_inference_steps}, timesteps_.Dtype());
+  float* timestep_tmp_data = reinterpret_cast<float*>(timestep_tmp.Data());
+  std::copy(timesteps_data, timesteps_data + num_inference_steps,
+            timestep_tmp_data);
+  timesteps_ = std::move(timestep_tmp);
+
+  function::Cast(timesteps_, &timesteps_, FDDataType::INT64);
+
+  model_outputs_.clear();
+  model_outputs_.resize(config.solver_order_);
+
+  lower_order_nums_ = 0;
+}
+
+// Advances `sample` by one solver step at `timestep`, writing the result to
+// `prev_sample`. SetTimesteps() must have been called first.
+void DPMSolverMultistepScheduler::Step(const FDTensor& model_output,
+                                       int timestep, const FDTensor& sample,
+                                       FDTensor* prev_sample) {
+  FDASSERT(num_inference_steps_ > -1,
+           "Number of inference steps is -1, you need to run SetTimesteps "
+           "after creating the scheduler");
+  int64_t step_index = timesteps_.Numel() - 1;
+  int64_t* timesteps_data = reinterpret_cast<int64_t*>(timesteps_.Data());
+  int64_t* timesteps_iter =
+      std::find(timesteps_data, timesteps_data + timesteps_.Numel(), timestep);
+  if (timesteps_iter - timesteps_data < timesteps_.Numel()) {
+    step_index = timesteps_iter - timesteps_data;
+  }
+
+  int64_t prev_timestep = 0;
+  if (step_index != timesteps_.Numel() - 1) {
+    prev_timestep = timesteps_data[step_index + 1];
+  }
+  bool lower_order_final = (step_index == timesteps_.Numel() - 1) &&
+                           config.lower_order_final_ &&
+                           (timesteps_.Numel() < 15);
+  bool lower_order_second = (step_index == timesteps_.Numel() - 2) &&
+                            config.lower_order_final_ &&
+                            (timesteps_.Numel() < 15);
+  FDTensor model_out;
+  ConvertModelOutput(model_output, timestep, sample, &model_out);
+  for (int i = 0; i < config.solver_order_ - 1; ++i) {
+    model_outputs_[i] = std::move(model_outputs_[i + 1]);
+  }
+  model_outputs_[config.solver_order_ - 1] = std::move(model_out);
+
+  if (config.solver_order_ == 1 || lower_order_nums_ < 1 || lower_order_final) {
+    DPMSolverFirstOrderUpdate(model_outputs_[config.solver_order_ - 1],
+                              timestep, prev_timestep, sample, prev_sample);
+  } else if (config.solver_order_ == 2 || lower_order_nums_ < 2 ||
+             lower_order_second) {
+    int t0 = reinterpret_cast<int64_t*>(timesteps_.Data())[step_index - 1];
+    std::vector<int> timestep_list = {t0, timestep};
+    MultiStepDPMSolverSecondOrderUpdate(model_outputs_, timestep_list,
+                                        prev_timestep, sample, prev_sample);
+  } else {
+    int t0 = reinterpret_cast<int64_t*>(timesteps_.Data())[step_index - 1];
+    int t1 = reinterpret_cast<int64_t*>(timesteps_.Data())[step_index - 2];
+    std::vector<int> timestep_list = {t1, t0, timestep};
+    MultiStepDPMSolverThirdOrderUpdate(model_outputs_, timestep_list,
+                                       prev_timestep, sample, prev_sample);
+  }
+
+  if (lower_order_nums_ < config.solver_order_) {
+    lower_order_nums_ += 1;
+  }
+}
+
+// Forward-diffuses `original_samples` with `noise` at `timesteps`:
+//   out = sqrt(alphas_cumprod[t]) * original_samples +
+//         sqrt(1 - alphas_cumprod[t]) * noise
+void DPMSolverMultistepScheduler::AddNoise(const FDTensor& original_samples,
+                                           const FDTensor& noise,
+                                           const FDTensor& timesteps,
+                                           FDTensor* out) {
+  // Note: casts the member table in place; the new dtype persists.
+  function::Cast(alphas_cumprod_, &alphas_cumprod_, original_samples.Dtype());
+
+  const int64_t* timesteps_data =
+      reinterpret_cast<const int64_t*>(timesteps.Data());
+  std::vector<int64_t> timesteps_vec;
+  for (int i = 0; i < timesteps.Numel(); ++i) {
+    timesteps_vec.push_back(timesteps_data[i]);
+  }
+  FDTensor sqrt_alpha_prod;
+  function::Slice(alphas_cumprod_, {0}, timesteps_vec, &sqrt_alpha_prod);
+  function::Sqrt(sqrt_alpha_prod, &sqrt_alpha_prod);
+  sqrt_alpha_prod.Reshape({-1});
+  int rank_diff =
+      original_samples.Shape().size() - sqrt_alpha_prod.Shape().size();
+  for (int i = 0; i < rank_diff; ++i) {
+    int curr_rank = sqrt_alpha_prod.Shape().size();
+    sqrt_alpha_prod.ExpandDim(curr_rank - 1);
+  }
+
+  FDTensor sqrt_one_minus_alpha_prod;
+  function::Slice(alphas_cumprod_, {0}, timesteps_vec,
+                  &sqrt_one_minus_alpha_prod);
+  sqrt_one_minus_alpha_prod = 1.0f - sqrt_one_minus_alpha_prod;
+  function::Sqrt(sqrt_one_minus_alpha_prod, &sqrt_one_minus_alpha_prod);
+  sqrt_one_minus_alpha_prod.Reshape({-1});
+  rank_diff = original_samples.Shape().size() -
+              sqrt_one_minus_alpha_prod.Shape().size();
+  for (int i = 0; i < rank_diff; ++i) {
+    int curr_rank = sqrt_one_minus_alpha_prod.Shape().size();
+    sqrt_one_minus_alpha_prod.ExpandDim(curr_rank - 1);
+  }
+  *out = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise;
+}
+
+}  // namespace fastdeploy
diff --git a/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.h b/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.h
new file mode 100644
index 000000000..c6f037fee
--- /dev/null
+++ b/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "./scheduler.h"
+#include "fastdeploy/core/fd_tensor.h"
+
+namespace fastdeploy {
+// Scheduler implementing the multistep DPM-Solver / DPM-Solver++ updates.
+class DPMSolverMultistepScheduler : public Scheduler {
+ public:
+  DPMSolverMultistepScheduler(int num_train_timesteps = 1000,
+                              float beta_start = 0.0001, float beta_end = 0.02,
+                              const std::string& beta_schedule = "linear",
+                              const std::vector<float>& trained_betas = {},
+                              int solver_order = 2, bool predict_epsilon = true,
+                              bool thresholding = false,
+                              float dynamic_thresholding_ratio = 0.995,
+                              float sample_max_value = 1.0,
+                              const std::string& algorithm_type = "dpmsolver++",
+                              const std::string& solver_type = "midpoint",
+                              bool lower_order_final = true);
+  void BetaForAlphaBar(FDTensor* out, int num_diffusion_timesteps,
+                       float max_beta = 0.999);  // "squaredcos_cap_v2" betas
+  void ConvertModelOutput(const FDTensor& model_output, int timestep,
+                          const FDTensor& sample, FDTensor* out);
+  void DPMSolverFirstOrderUpdate(const FDTensor& model_output, int timestep,
+                                 int prev_timestep, const FDTensor& sample,
+                                 FDTensor* out);
+  void MultiStepDPMSolverSecondOrderUpdate(
+      const std::vector<FDTensor>& model_output_list,
+      const std::vector<int>& timestep_list, int prev_timestep,
+      const FDTensor& sample, FDTensor* out);
+  void MultiStepDPMSolverThirdOrderUpdate(
+      const std::vector<FDTensor>& model_output_list,
+      const std::vector<int>& timestep_list, int prev_timestep,
+      const FDTensor& sample, FDTensor* out);
+  void SetTimesteps(int num_inference_steps) override;
+  void Step(const FDTensor& model_output, int timestep, const FDTensor& sample,
+            FDTensor* prev_sample) override;
+  void ScaleModelInput(const FDTensor& sample, FDTensor* out,
+                       const std::vector<FDTensor>& timesteps = {}) override;
+  void AddNoise(const FDTensor& original_samples, const FDTensor& noise,
+                const FDTensor& timesteps, FDTensor* out) override;
+  struct Config {
+    int num_train_timesteps_;
+    float beta_start_;
+    float beta_end_;
+    std::string beta_schedule_;
+    int solver_order_;
+    bool predict_epsilon_;
+    bool thresholding_;
+    float dynamic_thresholding_ratio_;
+    float sample_max_value_;
+    std::string algorithm_type_;  // "dpmsolver" or "dpmsolver++"
+    std::string solver_type_;     // "midpoint" or "heun"
+    bool lower_order_final_;
+  } config;  // constructor arguments, stored as passed
+
+ private:
+  FDTensor betas_;
+  FDTensor alphas_;          // 1 - betas_
+  FDTensor alphas_cumprod_;  // cumulative product of alphas_
+  FDTensor alpha_t_;         // sqrt(alphas_cumprod_)
+  FDTensor sigma_t_;         // sqrt(1 - alphas_cumprod_)
+  FDTensor lambda_t_;        // log(alpha_t_) - log(sigma_t_)
+  int num_inference_steps_;  // -1 until SetTimesteps() is called
+  FDTensor timesteps_;       // INT64 timesteps in descending order
+  int lower_order_nums_;     // steps taken so far, capped at solver_order_
+  std::vector<FDTensor> model_outputs_;  // history, most recent output last
+};
+
+}  // namespace fastdeploy
diff --git a/examples/multimodal/stable_diffusion/cpp/main.cc b/examples/multimodal/stable_diffusion/cpp/main.cc
new file mode 100644
index 000000000..3c7d33029
--- /dev/null
+++ b/examples/multimodal/stable_diffusion/cpp/main.cc
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dpm_solver_multistep_scheduler.h"
+#include <iostream>
+
+int main() {
+  fastdeploy::DPMSolverMultistepScheduler dpm(
+      /* num_train_timesteps = */ 1000,
+      /* beta_start = */ 0.00085,
+      /* beta_end = */ 0.012,
+      /* beta_schedule = */ "scaled_linear",
+      /* trained_betas = */ {},
+      /* solver_order = */ 2,
+      /* predict_epsilon = */ true,
+      /* thresholding = */ false,
+      /* dynamic_thresholding_ratio = */ 0.995,
+      /* sample_max_value = */ 1.0,
+      /* algorithm_type = */ "dpmsolver++",
+      /* solver_type = */ "midpoint",
+      /* lower_order_final = */ true);
+
+  return 0;
+}
diff --git a/examples/multimodal/stable_diffusion/cpp/scheduler.h b/examples/multimodal/stable_diffusion/cpp/scheduler.h
new file mode 100644
index 000000000..6a5cd2fed
--- /dev/null
+++ b/examples/multimodal/stable_diffusion/cpp/scheduler.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fastdeploy/core/fd_tensor.h"
+
+namespace fastdeploy {
+
+// Abstract interface implemented by the schedulers of this example.
+class Scheduler {
+ public:
+  // Virtual so a scheduler can be safely destroyed through a Scheduler*.
+  virtual ~Scheduler() = default;
+  // Configures the scheduler for a run of `num_inference_steps` steps.
+  virtual void SetTimesteps(int num_inference_steps) = 0;
+  // Performs one step at `timestep`, writing the result to `prev_sample`.
+  virtual void Step(const FDTensor& model_output, int timestep,
+                    const FDTensor& sample, FDTensor* prev_sample) = 0;
+  // Writes the model input derived from `sample` to `out`.
+  virtual void ScaleModelInput(const FDTensor& sample, FDTensor* out,
+                               const std::vector<FDTensor>& timesteps = {}) = 0;
+  // Combines `original_samples` and `noise` at `timesteps` into `out`.
+  virtual void AddNoise(const FDTensor& original_samples, const FDTensor& noise,
+                        const FDTensor& timesteps, FDTensor* out) = 0;
+};
+
+}  // namespace fastdeploy
diff --git a/fastdeploy/core/fd_tensor.cc b/fastdeploy/core/fd_tensor.cc
index 896f2ff3b..e84535ac9 100644
--- a/fastdeploy/core/fd_tensor.cc
+++ b/fastdeploy/core/fd_tensor.cc
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "fastdeploy/core/fd_tensor.h"
-#include "fastdeploy/core/fd_scalar.h"
 #include "fastdeploy/core/float16.h"
 #include "fastdeploy/utils/utils.h"
 
@@ -81,8 +80,7 @@ const void* FDTensor::CpuData() const {
 
 void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
                                const FDDataType& data_type, void* data_buffer,
-                               const Device& new_device,
-                               int new_device_id) {
+                               const Device& new_device, int new_device_id) {
   dtype = data_type;
   shape.assign(new_shape.begin(), new_shape.end());
   external_data_ptr = data_buffer;
diff --git a/fastdeploy/core/fd_tensor.h b/fastdeploy/core/fd_tensor.h
index ef9ff3796..3c79b0c88 100644
--- a/fastdeploy/core/fd_tensor.h
+++ b/fastdeploy/core/fd_tensor.h
@@ -19,12 +19,11 @@
 #include <vector>
 
 #include "fastdeploy/core/allocate.h"
+#include "fastdeploy/core/fd_scalar.h"
 #include "fastdeploy/core/fd_type.h"
 
 namespace fastdeploy {
 
-struct Scalar;
-
 struct FASTDEPLOY_DECL FDTensor {
   // std::vector<int8_t> data;
   void* buffer_ = nullptr;
diff --git a/fastdeploy/function/clip.cc b/fastdeploy/function/clip.cc
index bede9e56a..c4b2fa9e0 100644
--- a/fastdeploy/function/clip.cc
+++ b/fastdeploy/function/clip.cc
@@ -39,14
+39,15 @@ void ClipKernel(const FDTensor& x, double min, double max, FDTensor* out) {
            "max should be greater than or equal to min. But received min = %f, "
            "max = %f",
            static_cast<float>(min_), static_cast<float>(max_));
-
-  out->Allocate(x.Shape(), x.Dtype());
+  FDTensor tmp;
+  tmp.Allocate(x.Shape(), x.Dtype());
   const T* x_data = reinterpret_cast<const T*>(x.Data());
 
   int64_t numel = x.Numel();
-  T* out_data = reinterpret_cast<T*>(out->Data());
+  T* out_data = reinterpret_cast<T*>(tmp.Data());
 
   std::transform(x_data, x_data + numel, out_data, ClipFunctor<T>(min_, max_));
+  *out = std::move(tmp);
 }
 
 void Clip(const FDTensor& x, double min, double max, FDTensor* out) {
diff --git a/fastdeploy/function/elementwise.cc b/fastdeploy/function/elementwise.cc
index 120fe1678..5d94764de 100644
--- a/fastdeploy/function/elementwise.cc
+++ b/fastdeploy/function/elementwise.cc
@@ -86,4 +86,25 @@ FDTensor operator/(const FDTensor& x, const FDTensor& y) {
   return out;
 }
 
+#define INSTANTIATE_OPERATOR(operation_type)                                   \
+  template FDTensor operator operation_type(const FDTensor& x, bool y);       \
+  template FDTensor operator operation_type(const FDTensor& x, uint8_t y);    \
+  template FDTensor operator operation_type(const FDTensor& x, int16_t y);    \
+  template FDTensor operator operation_type(const FDTensor& x, int y);        \
+  template FDTensor operator operation_type(const FDTensor& x, int64_t y);    \
+  template FDTensor operator operation_type(const FDTensor& x, float y);      \
+  template FDTensor operator operation_type(const FDTensor& x, double y);     \
+  template FDTensor operator operation_type(bool x, const FDTensor& y);       \
+  template FDTensor operator operation_type(uint8_t x, const FDTensor& y);    \
+  template FDTensor operator operation_type(int16_t x, const FDTensor& y);    \
+  template FDTensor operator operation_type(int x, const FDTensor& y);        \
+  template FDTensor operator operation_type(int64_t x, const FDTensor& y);    \
+  template FDTensor operator operation_type(float x, const FDTensor& y);      \
+  template FDTensor operator operation_type(double x, const FDTensor& y)
+
+INSTANTIATE_OPERATOR(+);
+INSTANTIATE_OPERATOR(-);
+INSTANTIATE_OPERATOR(*);
+INSTANTIATE_OPERATOR(/);
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/function/elementwise.h b/fastdeploy/function/elementwise.h
index fd0a9c44b..53d34da6e 100644
--- a/fastdeploy/function/elementwise.h
+++ b/fastdeploy/function/elementwise.h
@@ -14,9 +14,11 @@
 
 #pragma once
 
+#include "fastdeploy/core/fd_scalar.h"
 #include "fastdeploy/core/fd_tensor.h"
 
 namespace fastdeploy {
+
 namespace function {
 
 /** Excute the add operation for input FDTensors. *out = x + y.
@@ -62,10 +64,42 @@ FASTDEPLOY_DECL void Maximum(const FDTensor& x, const FDTensor& y,
 
 FASTDEPLOY_DECL FDTensor operator+(const FDTensor& x, const FDTensor& y);
 
+template <typename T> FDTensor operator+(const FDTensor& x, T y) {
+  return x + FDTensor(Scalar(y));
+}
+
+template <typename T> FDTensor operator+(T x, const FDTensor& y) {
+  return FDTensor(Scalar(x)) + y;
+}
+
 FASTDEPLOY_DECL FDTensor operator-(const FDTensor& x, const FDTensor& y);
 
+template <typename T> FDTensor operator-(const FDTensor& x, T y) {
+  return x - FDTensor(Scalar(y));
+}
+
+template <typename T> FDTensor operator-(T x, const FDTensor& y) {
+  return FDTensor(Scalar(x)) - y;
+}
+
 FASTDEPLOY_DECL FDTensor operator*(const FDTensor& x, const FDTensor& y);
 
+template <typename T> FDTensor operator*(const FDTensor& x, T y) {
+  return x * FDTensor(Scalar(y));
+}
+
+template <typename T> FDTensor operator*(T x, const FDTensor& y) {
+  return FDTensor(Scalar(x)) * y;
+}
+
 FASTDEPLOY_DECL FDTensor operator/(const FDTensor& x, const FDTensor& y);
 
+template <typename T> FDTensor operator/(const FDTensor& x, T y) {
+  return x / FDTensor(Scalar(y));
+}
+
+template <typename T> FDTensor operator/(T x, const FDTensor& y) {
+  return FDTensor(Scalar(x)) / y;
+}
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/function/elementwise_base.h b/fastdeploy/function/elementwise_base.h
index e2fab684e..7ce1a694d 100644
--- a/fastdeploy/function/elementwise_base.h
+++ b/fastdeploy/function/elementwise_base.h
@@ -213,10 +213,12 @@ void CommonElementwiseBroadcastForward(const FDTensor& x, const FDTensor& y,
   GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                          y_dims_array.data(), out_dims_array.data(), max_dim,
                          axis);
-  z->Allocate(out_dims_array, TypeToDataType<OutType>::dtype);
+  FDTensor tmp;
+  tmp.Allocate(out_dims_array, TypeToDataType<OutType>::dtype);
   CommonForwardBroadcastCPU<Functor, T, OutType>(
-      x, y, z, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(),
-      max_dim, func, is_xsize_larger);
+      x, y, &tmp, x_dims_array.data(), y_dims_array.data(),
+      out_dims_array.data(), max_dim, func, is_xsize_larger);
+  *z = std::move(tmp);
 }
 
 template <typename Functor, typename T, typename OutType = T>
diff --git a/fastdeploy/function/slice.cc b/fastdeploy/function/slice.cc
index f374034f2..dab0ea023 100644
--- a/fastdeploy/function/slice.cc
+++ b/fastdeploy/function/slice.cc
@@ -163,5 +163,33 @@ void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
                      }));
 }
 
+// Indexing form of Slice: for every axes[i] takes the range
+// [index[i], index[i] + 1) and then squeezes those axes out of the result
+// (squeezing stops once the output has rank 1).
+void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
+           const std::vector<int64_t>& index, FDTensor* out) {
+  std::vector<int64_t> ends = index;
+  for (size_t i = 0; i < ends.size(); ++i) {
+    ends[i] += 1;
+  }
+  Slice(x, axes, index, ends, out);
+  const int64_t rank = static_cast<int64_t>(out->Shape().size());
+  for (size_t i = 0; i < axes.size(); ++i) {
+    if (out->Shape().size() <= 1) {
+      break;
+    }
+    // Each axis already squeezed below this one shifts it down by one, so
+    // compensate instead of reusing the pre-squeeze position.
+    const int64_t axis = axes[i] < 0 ? axes[i] + rank : axes[i];
+    int64_t removed_below = 0;
+    for (size_t j = 0; j < i; ++j) {
+      const int64_t prev = axes[j] < 0 ? axes[j] + rank : axes[j];
+      if (prev < axis) {
+        ++removed_below;
+      }
+    }
+    out->Squeeze(axis - removed_below);
+  }
+}
+
 }  // namespace function
 }  // namespace fastdeploy
diff --git a/fastdeploy/function/slice.h b/fastdeploy/function/slice.h
index d676a232e..e35ee5762 100644
--- a/fastdeploy/function/slice.h
+++ b/fastdeploy/function/slice.h
@@ -37,5 +37,8 @@ FASTDEPLOY_DECL void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
                            const std::vector<int64_t>& starts,
                            const std::vector<int64_t>& ends, FDTensor* out);
 
+FASTDEPLOY_DECL void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
+                           const std::vector<int64_t>& index, FDTensor* out);
+
 }  // namespace function
 }  // namespace fastdeploy
diff --git a/tests/function/test_elementwise.cc b/tests/function/test_elementwise.cc
index bd27e498a..5843ba932 100644
--- a/tests/function/test_elementwise.cc
+++ b/tests/function/test_elementwise.cc
@@ -164,6 +164,15 @@ TEST(fastdeploy, check_same_dim) {
   check_shape(z.shape, {2, 3, 4});
   check_data(reinterpret_cast<const float*>(z.Data()), maximum_result.data(),
              maximum_result.size());
+
+  x = 1.0f - x;
+  sub_result = {0.157138, 0.353809, 0.862595, 0.885693, 0.340074, 0.464184,
+                0.257084, 0.154395, 0.787718, 0.700299, 0.137829, 0.591059,
+                0.873153, 0.843381, 0.571159, 0.152347, 0.754137, 0.330954,
+                0.121117, 0.323741, 0.333547, 0.67477,  0.586061, 0.165859};
+  check_shape(x.shape, {2, 3, 4});
+  check_data(reinterpret_cast<const float*>(x.Data()), sub_result.data(),
+             sub_result.size());
 }
 
 TEST(fastdeploy, check_broadcast_dim1) {
@@ -498,6 +507,15 @@ TEST(fastdeploy, mixed_operation) {
   check_shape(output.shape, {2, 3, 4});
   check_data(reinterpret_cast<const float*>(output.Data()), result.data(),
              result.size());
+
+  result = {2.854443,  1.87709,   1.585621,  1.012709, 0.332781,  0.998346,
+            0.228024,  2.140475,  0.246941,  0.301517, 1.575438,  0.595582,
+            -0.410393, -0.163718, -0.405571, 0.58563,  -0.177035, 0.263035,
+            0.075725,  0.591098,  0.156365,  -0.106078, -0.475957, 0.626429};
+  output = a + b * c / d - e;
+  check_shape(output.shape, {2, 3, 4});
+  check_data(reinterpret_cast<const float*>(output.Data()), result.data(),
+             result.size());
 }
 
 }  // namespace function