mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Diffusion] Add C++ dpm solver (#714)
* Add BetaForAlphaBar, ConvertModelOutput, SetTimesteps, and constructor for DPMSolverMultistepScheduler * tmp * Add DPMSolverFirstOrderUpdate * Add ScaleModelInput * Add MultiStepDPMSolverSecondOrderUpdate * add MultiStepDPMSolverThirdOrderUpdate * Add Step * Add FASTDEPLOY_DECL * Add AddNoise * Fix operator * update * Fix DPMSolverMultistepScheduler * Upgrade Slice * Fix DPMSolverFirstOrderUpdate * remove FASTDEPLOY_DECL * Add config for dpm solver
This commit is contained in:
27
examples/multimodal/stable_diffusion/cpp/CMakeLists.txt
Normal file
27
examples/multimodal/stable_diffusion/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
PROJECT(main C CXX)
|
||||
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
|
||||
|
||||
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
|
||||
set(THIRD_LIBS "")
|
||||
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||
|
||||
include_directories(${FASTDEPLOY_INCS})
|
||||
|
||||
file(GLOB_RECURSE ALL_SRCS ${PROJECT_SOURCE_DIR}/*.cc)
|
||||
|
||||
add_executable(main ${ALL_SRCS})
|
||||
target_link_libraries(main ${FASTDEPLOY_LIBS} ${THIRD_LIBS})
|
@@ -0,0 +1,395 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "dpm_solver_multistep_scheduler.h"
|
||||
#include "fastdeploy/core/fd_scalar.h"
|
||||
#include "fastdeploy/function/functions.h"
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
void DPMSolverMultistepScheduler::BetaForAlphaBar(FDTensor* out,
|
||||
int num_diffusion_timesteps,
|
||||
float max_beta) {
|
||||
auto alpha_bar = [](float time_step) -> float {
|
||||
constexpr float pi = 3.14159265358979323846;
|
||||
return std::pow(std::cos((time_step + 0.008) / 1.008 * pi / 2), 2);
|
||||
};
|
||||
std::vector<FDTensor> betas;
|
||||
for (int i = 0; i < num_diffusion_timesteps; ++i) {
|
||||
float t1 = i / num_diffusion_timesteps;
|
||||
float t2 = (i + 1) / num_diffusion_timesteps;
|
||||
float beta_val = (std::min)(1 - alpha_bar(t1) / alpha_bar(t2), max_beta);
|
||||
betas.emplace_back(Scalar(beta_val));
|
||||
}
|
||||
function::Concat(betas, out);
|
||||
}
|
||||
|
||||
DPMSolverMultistepScheduler::DPMSolverMultistepScheduler(
|
||||
int num_train_timesteps, float beta_start, float beta_end,
|
||||
const std::string& beta_schedule, const std::vector<float>& trained_betas,
|
||||
int solver_order, bool predict_epsilon, bool thresholding,
|
||||
float dynamic_thresholding_ratio, float sample_max_value,
|
||||
const std::string& algorithm_type, const std::string& solver_type,
|
||||
bool lower_order_final)
|
||||
: config({num_train_timesteps, beta_start, beta_end, beta_schedule,
|
||||
solver_order, predict_epsilon, thresholding,
|
||||
dynamic_thresholding_ratio, sample_max_value, algorithm_type,
|
||||
solver_type, lower_order_final}) {
|
||||
int beta_size = trained_betas.size();
|
||||
if (beta_size > 0) {
|
||||
betas_.Allocate({beta_size}, FDDataType::FP32);
|
||||
std::copy(trained_betas.data(), trained_betas.data() + beta_size,
|
||||
reinterpret_cast<float*>(betas_.Data()));
|
||||
} else if (beta_schedule == "linear") {
|
||||
function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
|
||||
FDDataType::FP32);
|
||||
} else if (beta_schedule == "scaled_linear") {
|
||||
function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
|
||||
FDDataType::FP32);
|
||||
betas_ = betas_ * betas_;
|
||||
} else if (beta_schedule == "squaredcos_cap_v2") {
|
||||
BetaForAlphaBar(&betas_, num_train_timesteps);
|
||||
} else {
|
||||
FDASSERT(false, "%s is not implemented for DPMSolverMultistepScheduler",
|
||||
beta_schedule.c_str());
|
||||
}
|
||||
|
||||
alphas_ = 1.0f - betas_;
|
||||
function::Cumprod(alphas_, &alphas_cumprod_);
|
||||
function::Sqrt(alphas_cumprod_, &alpha_t_);
|
||||
function::Sqrt(1.0f - alphas_cumprod_, &sigma_t_);
|
||||
FDTensor alpha_t_log, sigma_t_log;
|
||||
function::Log(alpha_t_, &alpha_t_log);
|
||||
function::Log(sigma_t_, &sigma_t_log);
|
||||
lambda_t_ = alpha_t_log - sigma_t_log;
|
||||
|
||||
FDASSERT(config.algorithm_type_ == "dpmsolver" ||
|
||||
config.algorithm_type_ == "dpmsolver++",
|
||||
"%s does is not implemented for DPMSolverMultistepScheduler",
|
||||
config.algorithm_type_.c_str());
|
||||
FDASSERT(config.solver_type_ == "midpoint" || config.solver_type_ == "heun",
|
||||
"%s does is not implemented for DPMSolverMultistepScheduler",
|
||||
config.solver_type_.c_str());
|
||||
num_inference_steps_ = -1;
|
||||
|
||||
function::Linspace(0, config.num_train_timesteps_ - 1,
|
||||
config.num_train_timesteps_, ×teps_);
|
||||
function::Cast(timesteps_, ×teps_, FDDataType::INT64);
|
||||
// Reverse timesteps
|
||||
int64_t* timesteps_data = reinterpret_cast<int64_t*>(timesteps_.Data());
|
||||
std::reverse(timesteps_data, timesteps_data + timesteps_.Numel());
|
||||
|
||||
model_outputs_.resize(config.solver_order_);
|
||||
lower_order_nums_ = 0;
|
||||
}
|
||||
|
||||
void DPMSolverMultistepScheduler::ConvertModelOutput(
|
||||
const FDTensor& model_output, int timestep, const FDTensor& sample,
|
||||
FDTensor* out) {
|
||||
if (config.algorithm_type_ == "dpmsolver++") {
|
||||
FDTensor x0_pred;
|
||||
if (config.predict_epsilon_) {
|
||||
FDTensor alpha_t, sigma_t;
|
||||
function::Slice(alpha_t_, {0}, {timestep}, &alpha_t);
|
||||
function::Slice(sigma_t_, {0}, {timestep}, &sigma_t);
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t;
|
||||
} else {
|
||||
x0_pred = model_output;
|
||||
}
|
||||
if (config.thresholding_) {
|
||||
FDTensor dynamic_max_val, x0_pred_abs;
|
||||
function::Abs(x0_pred, &x0_pred_abs);
|
||||
x0_pred_abs.Reshape({x0_pred_abs.Shape()[0], -1});
|
||||
function::Quantile(x0_pred_abs, {config.dynamic_thresholding_ratio_}, {1},
|
||||
&dynamic_max_val);
|
||||
|
||||
FDTensor max_value, dy_max_val;
|
||||
function::FullLike(dynamic_max_val, config.sample_max_value_, &max_value,
|
||||
dynamic_max_val.Dtype());
|
||||
function::Maximum(dynamic_max_val, max_value, &dy_max_val);
|
||||
int expand_dims = x0_pred.Shape().size() - 1;
|
||||
for (int i = 0; i < expand_dims; ++i) {
|
||||
dy_max_val.ExpandDim(dy_max_val.Shape().size());
|
||||
}
|
||||
float clip_max = reinterpret_cast<float*>(dy_max_val.Data())[0];
|
||||
function::Clip(x0_pred, -clip_max, clip_max, &x0_pred);
|
||||
x0_pred = x0_pred / dy_max_val;
|
||||
}
|
||||
*out = std::move(x0_pred);
|
||||
} else if (config.algorithm_type_ == "dpmsolver") {
|
||||
if (config.predict_epsilon_) {
|
||||
*out = model_output;
|
||||
} else {
|
||||
FDTensor alpha_t, sigma_t;
|
||||
function::Slice(alpha_t_, {0}, {timestep}, &alpha_t);
|
||||
function::Slice(sigma_t_, {0}, {timestep}, &sigma_t);
|
||||
*out = (sample - (alpha_t * model_output)) / sigma_t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DPMSolverMultistepScheduler::DPMSolverFirstOrderUpdate(
|
||||
const FDTensor& model_output, int timestep, int prev_timestep,
|
||||
const FDTensor& sample, FDTensor* out) {
|
||||
FDTensor lambda_t, lambda_s;
|
||||
function::Slice(lambda_t_, {0}, {prev_timestep}, &lambda_t);
|
||||
function::Slice(lambda_t_, {0}, {timestep}, &lambda_s);
|
||||
|
||||
FDTensor alpha_t, alpha_s;
|
||||
function::Slice(alpha_t_, {0}, {prev_timestep}, &alpha_t);
|
||||
function::Slice(alpha_t_, {0}, {timestep}, &alpha_s);
|
||||
|
||||
FDTensor sigma_t, sigma_s;
|
||||
function::Slice(sigma_t_, {0}, {prev_timestep}, &sigma_t);
|
||||
function::Slice(sigma_t_, {0}, {timestep}, &sigma_s);
|
||||
|
||||
FDTensor h = lambda_t - lambda_s;
|
||||
if (config.algorithm_type_ == "dpmsolver++") {
|
||||
function::Exp(0.0f - h, &h);
|
||||
*out = (sigma_t / sigma_s) * sample - (alpha_t * (h - 1.0f)) * model_output;
|
||||
} else if (config.algorithm_type_ == "dpmsolver") {
|
||||
function::Exp(h, &h);
|
||||
*out = (alpha_t / alpha_s) * sample - (sigma_t * (h - 1.0f)) * model_output;
|
||||
}
|
||||
}
|
||||
|
||||
void DPMSolverMultistepScheduler::MultiStepDPMSolverSecondOrderUpdate(
|
||||
const std::vector<FDTensor>& model_output_list,
|
||||
const std::vector<int>& timestep_list, int prev_timestep,
|
||||
const FDTensor& sample, FDTensor* out) {
|
||||
int timestep_size = timestep_list.size();
|
||||
int model_output_size = model_output_list.size();
|
||||
int t = prev_timestep;
|
||||
int s0 = timestep_list[timestep_size - 1];
|
||||
int s1 = timestep_list[timestep_size - 2];
|
||||
const FDTensor& m0 = model_output_list[model_output_size - 1];
|
||||
const FDTensor& m1 = model_output_list[model_output_size - 2];
|
||||
FDTensor lambda_t, lambda_s0, lambda_s1;
|
||||
function::Slice(lambda_t_, {0}, {t}, &lambda_t);
|
||||
function::Slice(lambda_t_, {0}, {s0}, &lambda_s0);
|
||||
function::Slice(lambda_t_, {0}, {s1}, &lambda_s1);
|
||||
|
||||
FDTensor alpha_t, alpha_s0, sigma_t, sigma_s0;
|
||||
function::Slice(alpha_t_, {0}, {t}, &alpha_t);
|
||||
function::Slice(alpha_t_, {0}, {s0}, &alpha_s0);
|
||||
function::Slice(sigma_t_, {0}, {t}, &sigma_t);
|
||||
function::Slice(sigma_t_, {0}, {s0}, &sigma_s0);
|
||||
|
||||
FDTensor h = lambda_t - lambda_s0;
|
||||
FDTensor h0 = lambda_s0 - lambda_s1;
|
||||
FDTensor r0 = h0 / h;
|
||||
FDTensor D0 = m0;
|
||||
FDTensor D1 = (1.0f / r0) * (m0 - m1);
|
||||
if (config.algorithm_type_ == "dpmsolver++") {
|
||||
if (config.solver_type_ == "midpoint") {
|
||||
function::Exp(0.0f - h, &h);
|
||||
*out = (sigma_t / sigma_s0 * sample) - (alpha_t * (h - 1.0f) * D0) -
|
||||
(0.5f * alpha_t * (h - 1.0f) * D1);
|
||||
} else if (config.solver_type_ == "heun") {
|
||||
FDTensor h_exp;
|
||||
function::Exp(0.0f - h, &h_exp);
|
||||
*out = (sigma_t / sigma_s0 * sample) - (alpha_t * (h_exp - 1.0f) * D0) +
|
||||
(alpha_t * ((h_exp - 1.0f) / h + 1.0f) * D1);
|
||||
}
|
||||
} else if (config.algorithm_type_ == "dpmsolver") {
|
||||
FDTensor h_exp;
|
||||
function::Exp(h, &h_exp);
|
||||
if (config.solver_type_ == "midpoint") {
|
||||
*out = alpha_t / alpha_s0 * sample - sigma_t * (h_exp - 1.0f) * D0 -
|
||||
0.5 * (sigma_t * (h_exp - 1.0f) * D1);
|
||||
} else if (config.solver_type_ == "heun") {
|
||||
*out = alpha_t / alpha_s0 * sample - sigma_t * (h_exp - 1.0f) * D0 -
|
||||
(sigma_t * ((h_exp - 1.0f) / h - 1.0f) * D1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DPMSolverMultistepScheduler::MultiStepDPMSolverThirdOrderUpdate(
|
||||
const std::vector<FDTensor>& model_output_list,
|
||||
const std::vector<int>& timestep_list, int prev_timestep,
|
||||
const FDTensor& sample, FDTensor* out) {
|
||||
int timestep_size = timestep_list.size();
|
||||
int model_output_size = model_output_list.size();
|
||||
int t = prev_timestep;
|
||||
|
||||
int s0 = timestep_list[timestep_size - 1];
|
||||
int s1 = timestep_list[timestep_size - 2];
|
||||
int s2 = timestep_list[timestep_size - 3];
|
||||
const FDTensor& m0 = model_output_list[model_output_size - 1];
|
||||
const FDTensor& m1 = model_output_list[model_output_size - 2];
|
||||
const FDTensor& m2 = model_output_list[model_output_size - 3];
|
||||
|
||||
FDTensor lambda_t, lambda_s0, lambda_s1, lambda_s2;
|
||||
function::Slice(lambda_t_, {0}, {t}, &lambda_t);
|
||||
function::Slice(lambda_t_, {0}, {s0}, &lambda_s0);
|
||||
function::Slice(lambda_t_, {0}, {s1}, &lambda_s1);
|
||||
function::Slice(lambda_t_, {0}, {s2}, &lambda_s2);
|
||||
|
||||
FDTensor alpha_t, alpha_s0, sigma_t, sigma_s0;
|
||||
function::Slice(alpha_t_, {0}, {t}, &alpha_t);
|
||||
function::Slice(alpha_t_, {0}, {s0}, &alpha_s0);
|
||||
function::Slice(sigma_t_, {0}, {t}, &sigma_t);
|
||||
function::Slice(sigma_t_, {0}, {s0}, &sigma_s0);
|
||||
|
||||
FDTensor h = lambda_t - lambda_s0;
|
||||
FDTensor h0 = lambda_s0 - lambda_s1;
|
||||
FDTensor h1 = lambda_s1 - lambda_s2;
|
||||
|
||||
FDTensor r0 = h0 / h;
|
||||
FDTensor r1 = h1 / h;
|
||||
FDTensor D0 = m0;
|
||||
FDTensor D1_0 = (1.0f / r0) * (m0 - m1);
|
||||
FDTensor D1_1 = (1.0f / r1) * (m1 - m2);
|
||||
FDTensor D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1);
|
||||
FDTensor D2 = (1.0f / (r0 + r1)) * (D1_0 - D1_1);
|
||||
|
||||
if (config.algorithm_type_ == "dpmsolver++") {
|
||||
FDTensor h_exp;
|
||||
function::Exp(0.0f - h, &h_exp);
|
||||
*out = (sigma_t / sigma_s0) * sample - (alpha_t * (h_exp - 1.0f)) * D0 +
|
||||
(alpha_t * ((h_exp - 1.0) / h + 1.0)) * D1 -
|
||||
(alpha_t * ((h_exp - 1.0 + h) / (h * h) - 0.5)) * D2;
|
||||
|
||||
} else if (config.algorithm_type_ == "dpmsolver") {
|
||||
FDTensor h_exp;
|
||||
function::Exp(h, &h_exp);
|
||||
*out = (alpha_t / alpha_s0) * sample - (sigma_t * (h_exp - 1.0f)) * D0 +
|
||||
(sigma_t * ((h_exp - 1.0) / h - 1.0)) * D1 -
|
||||
(sigma_t * ((h_exp - 1.0 - h) / (h * h) - 0.5)) * D2;
|
||||
}
|
||||
}
|
||||
|
||||
void DPMSolverMultistepScheduler::ScaleModelInput(
|
||||
const FDTensor& sample, FDTensor* out,
|
||||
const std::vector<FDTensor>& timesteps) {
|
||||
*out = sample;
|
||||
}
|
||||
|
||||
void DPMSolverMultistepScheduler::SetTimesteps(int num_inference_steps) {
|
||||
num_inference_steps_ = num_inference_steps;
|
||||
function::Linspace(0, config.num_train_timesteps_ - 1,
|
||||
num_inference_steps + 1, ×teps_);
|
||||
function::Round(timesteps_, ×teps_);
|
||||
// Reverse timesteps
|
||||
float* timesteps_data = reinterpret_cast<float*>(timesteps_.Data());
|
||||
std::reverse(timesteps_data, timesteps_data + timesteps_.Numel());
|
||||
FDTensor timestep_tmp;
|
||||
timestep_tmp.Allocate({num_inference_steps}, timesteps_.Dtype());
|
||||
float* timestep_tmp_data = reinterpret_cast<float*>(timestep_tmp.Data());
|
||||
std::copy(timesteps_data, timesteps_data + num_inference_steps,
|
||||
timestep_tmp_data);
|
||||
timesteps_ = std::move(timestep_tmp);
|
||||
|
||||
function::Cast(timesteps_, ×teps_, FDDataType::INT64);
|
||||
|
||||
model_outputs_.clear();
|
||||
model_outputs_.resize(config.solver_order_);
|
||||
|
||||
lower_order_nums_ = 0;
|
||||
}
|
||||
|
||||
void DPMSolverMultistepScheduler::Step(const FDTensor& model_output,
|
||||
int timestep, const FDTensor& sample,
|
||||
FDTensor* prev_sample) {
|
||||
FDASSERT(num_inference_steps_ > -1,
|
||||
"Number of inference steps is -1, you need to run SetTimesteps "
|
||||
"after creating the scheduler");
|
||||
int64_t step_index = timesteps_.Numel() - 1;
|
||||
int64_t* timesteps_data = reinterpret_cast<int64_t*>(timesteps_.Data());
|
||||
int64_t* timesteps_iter =
|
||||
std::find(timesteps_data, timesteps_data + timesteps_.Numel(), timestep);
|
||||
if (timesteps_iter - timesteps_data < timesteps_.Numel()) {
|
||||
step_index = timesteps_iter - timesteps_data;
|
||||
}
|
||||
|
||||
int64_t prev_timestep = 0;
|
||||
if (step_index != timesteps_.Numel() - 1) {
|
||||
prev_timestep = timesteps_data[step_index + 1];
|
||||
}
|
||||
bool lower_order_final = (step_index == timesteps_.Numel() - 1) &&
|
||||
config.lower_order_final_ &&
|
||||
(timesteps_.Numel() < 15);
|
||||
bool lower_order_second = (step_index == timesteps_.Numel() - 2) &&
|
||||
config.lower_order_final_ &&
|
||||
(timesteps_.Numel() < 15);
|
||||
FDTensor model_out;
|
||||
ConvertModelOutput(model_output, timestep, sample, &model_out);
|
||||
for (int i = 0; i < config.solver_order_ - 1; ++i) {
|
||||
model_outputs_[i] = std::move(model_outputs_[i + 1]);
|
||||
}
|
||||
model_outputs_[config.solver_order_ - 1] = std::move(model_out);
|
||||
|
||||
if (config.solver_order_ == 1 || lower_order_nums_ < 1 || lower_order_final) {
|
||||
DPMSolverFirstOrderUpdate(model_outputs_[config.solver_order_ - 1],
|
||||
timestep, prev_timestep, sample, prev_sample);
|
||||
} else if (config.solver_order_ == 2 || lower_order_nums_ < 2 ||
|
||||
lower_order_second) {
|
||||
int t0 = reinterpret_cast<int64_t*>(timesteps_.Data())[step_index - 1];
|
||||
std::vector<int> timestep_list = {t0, timestep};
|
||||
MultiStepDPMSolverSecondOrderUpdate(model_outputs_, timestep_list,
|
||||
prev_timestep, sample, prev_sample);
|
||||
} else {
|
||||
int t0 = reinterpret_cast<int64_t*>(timesteps_.Data())[step_index - 1];
|
||||
int t1 = reinterpret_cast<int64_t*>(timesteps_.Data())[step_index - 2];
|
||||
std::vector<int> timestep_list = {t1, t0, timestep};
|
||||
MultiStepDPMSolverThirdOrderUpdate(model_outputs_, timestep_list,
|
||||
prev_timestep, sample, prev_sample);
|
||||
}
|
||||
|
||||
if (lower_order_nums_ < config.solver_order_) {
|
||||
lower_order_nums_ += 1;
|
||||
}
|
||||
}
|
||||
|
||||
void DPMSolverMultistepScheduler::AddNoise(const FDTensor& original_samples,
|
||||
const FDTensor& noise,
|
||||
const FDTensor& timesteps,
|
||||
FDTensor* out) {
|
||||
function::Cast(alphas_cumprod_, &alphas_cumprod_, original_samples.Dtype());
|
||||
|
||||
const int64_t* timesteps_data =
|
||||
reinterpret_cast<const int64_t*>(timesteps.Data());
|
||||
std::vector<int64_t> timesteps_vec;
|
||||
for (int i = 0; i < timesteps.Numel(); ++i) {
|
||||
timesteps_vec.push_back(timesteps_data[i]);
|
||||
}
|
||||
FDTensor sqrt_alpha_prod;
|
||||
function::Slice(alphas_cumprod_, {0}, timesteps_vec, &sqrt_alpha_prod);
|
||||
function::Sqrt(sqrt_alpha_prod, &sqrt_alpha_prod);
|
||||
sqrt_alpha_prod.Reshape({-1});
|
||||
int rank_diff =
|
||||
original_samples.Shape().size() - sqrt_alpha_prod.Shape().size();
|
||||
for (int i = 0; i < rank_diff; ++i) {
|
||||
int curr_rank = sqrt_alpha_prod.Shape().size();
|
||||
sqrt_alpha_prod.ExpandDim(curr_rank - 1);
|
||||
}
|
||||
|
||||
FDTensor sqrt_one_minus_alpha_prod;
|
||||
function::Slice(alphas_cumprod_, {0}, timesteps_vec,
|
||||
&sqrt_one_minus_alpha_prod);
|
||||
sqrt_one_minus_alpha_prod = 1.0f - sqrt_one_minus_alpha_prod;
|
||||
function::Sqrt(sqrt_one_minus_alpha_prod, &sqrt_one_minus_alpha_prod);
|
||||
sqrt_one_minus_alpha_prod.Reshape({-1});
|
||||
rank_diff = original_samples.Shape().size() -
|
||||
sqrt_one_minus_alpha_prod.Shape().size();
|
||||
for (int i = 0; i < rank_diff; ++i) {
|
||||
int curr_rank = sqrt_one_minus_alpha_prod.Shape().size();
|
||||
sqrt_one_minus_alpha_prod.ExpandDim(curr_rank - 1);
|
||||
}
|
||||
*out = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise;
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "./scheduler.h"
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
class DPMSolverMultistepScheduler : public Scheduler {
|
||||
public:
|
||||
DPMSolverMultistepScheduler(int num_train_timesteps = 1000,
|
||||
float beta_start = 0.0001, float beta_end = 0.02,
|
||||
const std::string& beta_schedule = "linear",
|
||||
const std::vector<float>& trained_betas = {},
|
||||
int solver_order = 2, bool predict_epsilon = true,
|
||||
bool thresholding = false,
|
||||
float dynamic_thresholding_ratio = 0.995,
|
||||
float sample_max_value = 1.0,
|
||||
const std::string& algorithm_type = "dpmsolver++",
|
||||
const std::string& solver_type = "midpoint",
|
||||
bool lower_order_final = true);
|
||||
void BetaForAlphaBar(FDTensor* out, int num_diffusion_timesteps,
|
||||
float max_beta = 0.999);
|
||||
void ConvertModelOutput(const FDTensor& model_output, int timestep,
|
||||
const FDTensor& sample, FDTensor* out);
|
||||
void DPMSolverFirstOrderUpdate(const FDTensor& model_output, int timestep,
|
||||
int prev_timestep, const FDTensor& sample,
|
||||
FDTensor* out);
|
||||
void MultiStepDPMSolverSecondOrderUpdate(
|
||||
const std::vector<FDTensor>& model_output_list,
|
||||
const std::vector<int>& timestep_list, int prev_timestep,
|
||||
const FDTensor& sample, FDTensor* out);
|
||||
void MultiStepDPMSolverThirdOrderUpdate(
|
||||
const std::vector<FDTensor>& model_output_list,
|
||||
const std::vector<int>& timestep_list, int prev_timestep,
|
||||
const FDTensor& sample, FDTensor* out);
|
||||
void SetTimesteps(int num_inference_steps) override;
|
||||
void Step(const FDTensor& model_output, int timestep, const FDTensor& sample,
|
||||
FDTensor* prev_sample) override;
|
||||
void ScaleModelInput(const FDTensor& sample, FDTensor* out,
|
||||
const std::vector<FDTensor>& timesteps = {}) override;
|
||||
void AddNoise(const FDTensor& original_samples, const FDTensor& noise,
|
||||
const FDTensor& timesteps, FDTensor* out) override;
|
||||
struct Config {
|
||||
int num_train_timesteps_;
|
||||
float beta_start_;
|
||||
float beta_end_;
|
||||
std::string beta_schedule_;
|
||||
int solver_order_;
|
||||
bool predict_epsilon_;
|
||||
bool thresholding_;
|
||||
float dynamic_thresholding_ratio_;
|
||||
float sample_max_value_;
|
||||
std::string algorithm_type_;
|
||||
std::string solver_type_;
|
||||
bool lower_order_final_;
|
||||
} config;
|
||||
|
||||
private:
|
||||
FDTensor betas_;
|
||||
FDTensor alphas_;
|
||||
FDTensor alphas_cumprod_;
|
||||
FDTensor alpha_t_;
|
||||
FDTensor sigma_t_;
|
||||
FDTensor lambda_t_;
|
||||
int num_inference_steps_;
|
||||
FDTensor timesteps_;
|
||||
int lower_order_nums_;
|
||||
std::vector<FDTensor> model_outputs_;
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
35
examples/multimodal/stable_diffusion/cpp/main.cc
Normal file
35
examples/multimodal/stable_diffusion/cpp/main.cc
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "dpm_solver_multistep_scheduler.h"
|
||||
#include <iostream>
|
||||
|
||||
int main() {
|
||||
fastdeploy::DPMSolverMultistepScheduler dpm(
|
||||
/* num_train_timesteps */ 1000,
|
||||
/* beta_start = */ 0.00085,
|
||||
/* beta_end = */ 0.012,
|
||||
/* beta_schedule = */ "scaled_linear",
|
||||
/* trained_betas = */ {},
|
||||
/* solver_order = */ 2,
|
||||
/* predict_epsilon = */ true,
|
||||
/* thresholding = */ false,
|
||||
/* dynamic_thresholding_ratio = */ 0.995,
|
||||
/* sample_max_value = */ 1.0,
|
||||
/* algorithm_type = */ "dpmsolver++",
|
||||
/* solver_type = */ "midpoint",
|
||||
/* lower_order_final = */ true);
|
||||
|
||||
return 0;
|
||||
}
|
31
examples/multimodal/stable_diffusion/cpp/scheduler.h
Normal file
31
examples/multimodal/stable_diffusion/cpp/scheduler.h
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
class Scheduler {
|
||||
virtual void SetTimesteps(int num_inference_steps) = 0;
|
||||
virtual void Step(const FDTensor& model_output, int timestep,
|
||||
const FDTensor& sample, FDTensor* prev_sample) = 0;
|
||||
virtual void ScaleModelInput(const FDTensor& sample, FDTensor* out,
|
||||
const std::vector<FDTensor>& timesteps = {}) = 0;
|
||||
virtual void AddNoise(const FDTensor& original_samples, const FDTensor& noise,
|
||||
const FDTensor& timesteps, FDTensor* out) = 0;
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
#include "fastdeploy/core/fd_scalar.h"
|
||||
#include "fastdeploy/core/float16.h"
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
|
||||
@@ -81,8 +80,7 @@ const void* FDTensor::CpuData() const {
|
||||
|
||||
void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
|
||||
const FDDataType& data_type, void* data_buffer,
|
||||
const Device& new_device,
|
||||
int new_device_id) {
|
||||
const Device& new_device, int new_device_id) {
|
||||
dtype = data_type;
|
||||
shape.assign(new_shape.begin(), new_shape.end());
|
||||
external_data_ptr = data_buffer;
|
||||
|
@@ -19,12 +19,11 @@
|
||||
#include <vector>
|
||||
|
||||
#include "fastdeploy/core/allocate.h"
|
||||
#include "fastdeploy/core/fd_scalar.h"
|
||||
#include "fastdeploy/core/fd_type.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
struct Scalar;
|
||||
|
||||
struct FASTDEPLOY_DECL FDTensor {
|
||||
// std::vector<int8_t> data;
|
||||
void* buffer_ = nullptr;
|
||||
|
@@ -39,14 +39,15 @@ void ClipKernel(const FDTensor& x, double min, double max, FDTensor* out) {
|
||||
"max should be greater than or equal to min. But received min = %f, "
|
||||
"max = %f",
|
||||
static_cast<float>(min_), static_cast<float>(max_));
|
||||
|
||||
out->Allocate(x.Shape(), x.Dtype());
|
||||
FDTensor tmp;
|
||||
tmp.Allocate(x.Shape(), x.Dtype());
|
||||
const T* x_data = reinterpret_cast<const T*>(x.Data());
|
||||
|
||||
int64_t numel = x.Numel();
|
||||
T* out_data = reinterpret_cast<T*>(out->Data());
|
||||
T* out_data = reinterpret_cast<T*>(tmp.Data());
|
||||
|
||||
std::transform(x_data, x_data + numel, out_data, ClipFunctor<T>(min_, max_));
|
||||
*out = std::move(tmp);
|
||||
}
|
||||
|
||||
void Clip(const FDTensor& x, double min, double max, FDTensor* out) {
|
||||
|
@@ -86,4 +86,25 @@ FDTensor operator/(const FDTensor& x, const FDTensor& y) {
|
||||
return out;
|
||||
}
|
||||
|
||||
#define INSTANTIATE_OPERATOR(operation_type) \
|
||||
template FDTensor operator operation_type(const FDTensor& x, bool y); \
|
||||
template FDTensor operator operation_type(const FDTensor& x, uint8_t y); \
|
||||
template FDTensor operator operation_type(const FDTensor& x, int16_t y); \
|
||||
template FDTensor operator operation_type(const FDTensor& x, int y); \
|
||||
template FDTensor operator operation_type(const FDTensor& x, int64_t y); \
|
||||
template FDTensor operator operation_type(const FDTensor& x, float y); \
|
||||
template FDTensor operator operation_type(const FDTensor& x, double y); \
|
||||
template FDTensor operator operation_type(bool x, const FDTensor& y); \
|
||||
template FDTensor operator operation_type(uint8_t x, const FDTensor& y); \
|
||||
template FDTensor operator operation_type(int16_t x, const FDTensor& y); \
|
||||
template FDTensor operator operation_type(int x, const FDTensor& y); \
|
||||
template FDTensor operator operation_type(int64_t x, const FDTensor& y); \
|
||||
template FDTensor operator operation_type(float x, const FDTensor& y); \
|
||||
template FDTensor operator operation_type(double x, const FDTensor& y)
|
||||
|
||||
INSTANTIATE_OPERATOR(+);
|
||||
INSTANTIATE_OPERATOR(-);
|
||||
INSTANTIATE_OPERATOR(*);
|
||||
INSTANTIATE_OPERATOR(/);
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -14,9 +14,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/core/fd_scalar.h"
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
namespace function {
|
||||
|
||||
/** Excute the add operation for input FDTensors. *out = x + y.
|
||||
@@ -62,10 +64,42 @@ FASTDEPLOY_DECL void Maximum(const FDTensor& x, const FDTensor& y,
|
||||
|
||||
FASTDEPLOY_DECL FDTensor operator+(const FDTensor& x, const FDTensor& y);
|
||||
|
||||
template <typename T> FDTensor operator+(const FDTensor& x, T y) {
|
||||
return x + FDTensor(Scalar(y));
|
||||
}
|
||||
|
||||
template <typename T> FDTensor operator+(T x, const FDTensor& y) {
|
||||
return FDTensor(Scalar(x)) + y;
|
||||
}
|
||||
|
||||
FASTDEPLOY_DECL FDTensor operator-(const FDTensor& x, const FDTensor& y);
|
||||
|
||||
template <typename T> FDTensor operator-(const FDTensor& x, T y) {
|
||||
return x - FDTensor(Scalar(y));
|
||||
}
|
||||
|
||||
template <typename T> FDTensor operator-(T x, const FDTensor& y) {
|
||||
return FDTensor(Scalar(x)) - y;
|
||||
}
|
||||
|
||||
FASTDEPLOY_DECL FDTensor operator*(const FDTensor& x, const FDTensor& y);
|
||||
|
||||
template <typename T> FDTensor operator*(const FDTensor& x, T y) {
|
||||
return x * FDTensor(Scalar(y));
|
||||
}
|
||||
|
||||
template <typename T> FDTensor operator*(T x, const FDTensor& y) {
|
||||
return FDTensor(Scalar(x)) * y;
|
||||
}
|
||||
|
||||
FASTDEPLOY_DECL FDTensor operator/(const FDTensor& x, const FDTensor& y);
|
||||
|
||||
template <typename T> FDTensor operator/(const FDTensor& x, T y) {
|
||||
return x / FDTensor(Scalar(y));
|
||||
}
|
||||
|
||||
template <typename T> FDTensor operator/(T x, const FDTensor& y) {
|
||||
return FDTensor(Scalar(x)) / y;
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -213,10 +213,12 @@ void CommonElementwiseBroadcastForward(const FDTensor& x, const FDTensor& y,
|
||||
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
|
||||
y_dims_array.data(), out_dims_array.data(), max_dim,
|
||||
axis);
|
||||
z->Allocate(out_dims_array, TypeToDataType<OutType>::dtype);
|
||||
FDTensor tmp;
|
||||
tmp.Allocate(out_dims_array, TypeToDataType<OutType>::dtype);
|
||||
CommonForwardBroadcastCPU<Functor, T, OutType>(
|
||||
x, y, z, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(),
|
||||
max_dim, func, is_xsize_larger);
|
||||
x, y, &tmp, x_dims_array.data(), y_dims_array.data(),
|
||||
out_dims_array.data(), max_dim, func, is_xsize_larger);
|
||||
*z = std::move(tmp);
|
||||
}
|
||||
|
||||
template <typename Functor, typename T, typename OutType = T>
|
||||
|
@@ -163,5 +163,20 @@ void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
|
||||
}));
|
||||
}
|
||||
|
||||
void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
|
||||
const std::vector<int64_t>& index, FDTensor* out) {
|
||||
std::vector<int64_t> ends = index;
|
||||
for (int i = 0; i < ends.size(); ++i) {
|
||||
ends[i] += 1;
|
||||
}
|
||||
Slice(x, axes, index, ends, out);
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
if (out->Shape().size() <= 1) {
|
||||
break;
|
||||
}
|
||||
out->Squeeze(axes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace function
|
||||
} // namespace fastdeploy
|
||||
|
@@ -37,5 +37,8 @@ FASTDEPLOY_DECL void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
|
||||
const std::vector<int64_t>& starts,
|
||||
const std::vector<int64_t>& ends, FDTensor* out);
|
||||
|
||||
FASTDEPLOY_DECL void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
|
||||
const std::vector<int64_t>& index, FDTensor* out);
|
||||
|
||||
} // namespace function
|
||||
} // namespace fastdeploy
|
||||
|
@@ -164,6 +164,15 @@ TEST(fastdeploy, check_same_dim) {
|
||||
check_shape(z.shape, {2, 3, 4});
|
||||
check_data(reinterpret_cast<const float*>(z.Data()), maximum_result.data(),
|
||||
maximum_result.size());
|
||||
|
||||
x = 1.0f - x;
|
||||
sub_result = {0.157138, 0.353809, 0.862595, 0.885693, 0.340074, 0.464184,
|
||||
0.257084, 0.154395, 0.787718, 0.700299, 0.137829, 0.591059,
|
||||
0.873153, 0.843381, 0.571159, 0.152347, 0.754137, 0.330954,
|
||||
0.121117, 0.323741, 0.333547, 0.67477, 0.586061, 0.165859};
|
||||
check_shape(x.shape, {2, 3, 4});
|
||||
check_data(reinterpret_cast<const float*>(x.Data()), sub_result.data(),
|
||||
sub_result.size());
|
||||
}
|
||||
|
||||
TEST(fastdeploy, check_broadcast_dim1) {
|
||||
@@ -498,6 +507,15 @@ TEST(fastdeploy, mixed_operation) {
|
||||
check_shape(output.shape, {2, 3, 4});
|
||||
check_data(reinterpret_cast<const float*>(output.Data()), result.data(),
|
||||
result.size());
|
||||
|
||||
result = {2.854443, 1.87709, 1.585621, 1.012709, 0.332781, 0.998346,
|
||||
0.228024, 2.140475, 0.246941, 0.301517, 1.575438, 0.595582,
|
||||
-0.410393, -0.163718, -0.405571, 0.58563, -0.177035, 0.263035,
|
||||
0.075725, 0.591098, 0.156365, -0.106078, -0.475957, 0.626429};
|
||||
output = a + b * c / d - e;
|
||||
check_shape(output.shape, {2, 3, 4});
|
||||
check_data(reinterpret_cast<const float*>(output.Data()), result.data(),
|
||||
result.size());
|
||||
}
|
||||
|
||||
} // namespace function
|
||||
|
Reference in New Issue
Block a user