mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
[Diffusion] Add StableDiffusionInpaint pipeline (#760)
* Update Inpaint pipeline * Update concat * Add GaussianRandomKernel * Update GaussianRandom * Add vae endoder * Add unet infer * Add vae decoder predict * add PrepareMaskAndMaskedImage * Add imwrite * Add time counter * Fix pipeline * use FDTensor move * Fix scaled_linear dpm solver * Add RGB2BGR
This commit is contained in:
@@ -57,8 +57,8 @@ DPMSolverMultistepScheduler::DPMSolverMultistepScheduler(
|
||||
function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
|
||||
FDDataType::FP32);
|
||||
} else if (beta_schedule == "scaled_linear") {
|
||||
function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
|
||||
FDDataType::FP32);
|
||||
function::Linspace(std::sqrt(beta_start), std::sqrt(beta_end),
|
||||
num_train_timesteps, &betas_, FDDataType::FP32);
|
||||
betas_ = betas_ * betas_;
|
||||
} else if (beta_schedule == "squaredcos_cap_v2") {
|
||||
BetaForAlphaBar(&betas_, num_train_timesteps);
|
||||
@@ -96,6 +96,8 @@ DPMSolverMultistepScheduler::DPMSolverMultistepScheduler(
|
||||
lower_order_nums_ = 0;
|
||||
}
|
||||
|
||||
float DPMSolverMultistepScheduler::InitNoiseSigma() { return 1.0; }
|
||||
|
||||
void DPMSolverMultistepScheduler::ConvertModelOutput(
|
||||
const FDTensor& model_output, int timestep, const FDTensor& sample,
|
||||
FDTensor* out) {
|
||||
@@ -314,7 +316,6 @@ void DPMSolverMultistepScheduler::Step(const FDTensor& model_output,
|
||||
if (timesteps_iter - timesteps_data < timesteps_.Numel()) {
|
||||
step_index = timesteps_iter - timesteps_data;
|
||||
}
|
||||
|
||||
int64_t prev_timestep = 0;
|
||||
if (step_index != timesteps_.Numel() - 1) {
|
||||
prev_timestep = timesteps_data[step_index + 1];
|
||||
@@ -392,4 +393,6 @@ void DPMSolverMultistepScheduler::AddNoise(const FDTensor& original_samples,
|
||||
*out = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise;
|
||||
}
|
||||
|
||||
FDTensor DPMSolverMultistepScheduler::GetTimesteps() { return timesteps_; }
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -54,6 +54,8 @@ class DPMSolverMultistepScheduler : public Scheduler {
|
||||
const std::vector<FDTensor>& timesteps = {}) override;
|
||||
void AddNoise(const FDTensor& original_samples, const FDTensor& noise,
|
||||
const FDTensor& timesteps, FDTensor* out) override;
|
||||
float InitNoiseSigma() override;
|
||||
FDTensor GetTimesteps() override;
|
||||
struct Config {
|
||||
int num_train_timesteps_;
|
||||
float beta_start_;
|
||||
|
@@ -13,10 +13,55 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "dpm_solver_multistep_scheduler.h"
|
||||
#include "fastdeploy/vision/common/processors/mat.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "opencv2/highgui/highgui.hpp"
|
||||
#include "opencv2/imgproc/imgproc.hpp"
|
||||
#include "pipeline_stable_diffusion_inpaint.h"
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
template <typename T> std::string Str(const T* value, int size) {
|
||||
std::ostringstream oss;
|
||||
oss << "[ " << value[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
oss << " ," << value[i];
|
||||
}
|
||||
oss << " ]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
std::unique_ptr<fastdeploy::Runtime>
|
||||
CreateRuntime(const std::string& model_file, const std::string& params_file,
|
||||
bool use_paddle_backend = true) {
|
||||
fastdeploy::RuntimeOption runtime_option;
|
||||
runtime_option.SetModelPath(model_file, params_file,
|
||||
fastdeploy::ModelFormat::PADDLE);
|
||||
runtime_option.UseGpu();
|
||||
if (use_paddle_backend) {
|
||||
runtime_option.UsePaddleBackend();
|
||||
} else {
|
||||
runtime_option.UseOrtBackend();
|
||||
}
|
||||
std::unique_ptr<fastdeploy::Runtime> runtime =
|
||||
std::unique_ptr<fastdeploy::Runtime>(new fastdeploy::Runtime());
|
||||
if (!runtime->Init(runtime_option)) {
|
||||
std::cerr << "--- Init FastDeploy Runitme Failed! "
|
||||
<< "\n--- Model: " << model_file << std::endl;
|
||||
return nullptr;
|
||||
} else {
|
||||
std::cout << "--- Init FastDeploy Runitme Done! "
|
||||
<< "\n--- Model: " << model_file << std::endl;
|
||||
}
|
||||
return runtime;
|
||||
}
|
||||
|
||||
int main() {
|
||||
fastdeploy::DPMSolverMultistepScheduler dpm(
|
||||
// 1. Init scheduler
|
||||
std::unique_ptr<fastdeploy::Scheduler> dpm(
|
||||
new fastdeploy::DPMSolverMultistepScheduler(
|
||||
/* num_train_timesteps */ 1000,
|
||||
/* beta_start = */ 0.00085,
|
||||
/* beta_end = */ 0.012,
|
||||
@@ -29,7 +74,60 @@ int main() {
|
||||
/* sample_max_value = */ 1.0,
|
||||
/* algorithm_type = */ "dpmsolver++",
|
||||
/* solver_type = */ "midpoint",
|
||||
/* lower_order_final = */ true);
|
||||
/* lower_order_final = */ true));
|
||||
|
||||
// 2. Init text encoder runtime
|
||||
std::string text_model_file = "sd15_inpaint/text_encoder/inference.pdmodel";
|
||||
std::string text_params_file =
|
||||
"sd15_inpaint/text_encoder/inference.pdiparams";
|
||||
std::unique_ptr<fastdeploy::Runtime> text_encoder_runtime =
|
||||
CreateRuntime(text_model_file, text_params_file, false);
|
||||
|
||||
// 3. Init vae encoder runtime
|
||||
std::string vae_encoder_model_file =
|
||||
"sd15_inpaint/vae_encoder/inference.pdmodel";
|
||||
std::string vae_encoder_params_file =
|
||||
"sd15_inpaint/vae_encoder/inference.pdiparams";
|
||||
std::unique_ptr<fastdeploy::Runtime> vae_encoder_runtime =
|
||||
CreateRuntime(vae_encoder_model_file, vae_encoder_params_file);
|
||||
|
||||
// 4. Init vae decoder runtime
|
||||
std::string vae_decoder_model_file =
|
||||
"sd15_inpaint/vae_decoder/inference.pdmodel";
|
||||
std::string vae_decoder_params_file =
|
||||
"sd15_inpaint/vae_decoder/inference.pdiparams";
|
||||
std::unique_ptr<fastdeploy::Runtime> vae_decoder_runtime =
|
||||
CreateRuntime(vae_decoder_model_file, vae_decoder_params_file);
|
||||
|
||||
// 5. Init unet runtime
|
||||
std::string unet_model_file = "sd15_inpaint/unet/inference.pdmodel";
|
||||
std::string unet_params_file = "sd15_inpaint/unet/inference.pdiparams";
|
||||
std::unique_ptr<fastdeploy::Runtime> unet_runtime =
|
||||
CreateRuntime(unet_model_file, unet_params_file);
|
||||
|
||||
// 6. Init fast tokenizer
|
||||
paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer tokenizer(
|
||||
"clip/vocab.json", "clip/merges.txt", /* max_length = */ 77);
|
||||
fastdeploy::StableDiffusionInpaintPipeline pipe(
|
||||
std::move(vae_encoder_runtime), std::move(vae_decoder_runtime),
|
||||
std::move(text_encoder_runtime), std::move(unet_runtime),
|
||||
/* scheduler = */ std::move(dpm), tokenizer);
|
||||
|
||||
// 7. Read images
|
||||
auto image = cv::imread("overture-creations.png");
|
||||
auto mask_image = cv::imread("overture-creations-mask.png");
|
||||
|
||||
// 8. Predict
|
||||
std::vector<std::string> prompts = {
|
||||
"Face of a yellow cat, high resolution, sitting on a park bench"};
|
||||
std::vector<fastdeploy::FDTensor> outputs;
|
||||
fastdeploy::TimeCounter tc;
|
||||
tc.Start();
|
||||
pipe.Predict(prompts, image, mask_image, &outputs, /* height = */ 512,
|
||||
/* width = */ 512, /* num_inference_steps = */ 50);
|
||||
tc.End();
|
||||
tc.PrintInfo();
|
||||
fastdeploy::vision::FDMat mat = fastdeploy::vision::FDMat::Create(outputs[0]);
|
||||
cv::imwrite("cat_on_bench_new.png", *mat.GetOpenCVMat());
|
||||
return 0;
|
||||
}
|
@@ -0,0 +1,322 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "pipeline_stable_diffusion_inpaint.h"
|
||||
#include "fastdeploy/function/functions.h"
|
||||
#include "fastdeploy/vision/common/processors/color_space_convert.h"
|
||||
#include "fastdeploy/vision/common/processors/mat.h"
|
||||
#include "fastdeploy/vision/common/processors/resize.h"
|
||||
#include <algorithm>
|
||||
|
||||
using namespace paddlenlp;
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
static constexpr int NUM_LATENT_CHANNELS = 4;
|
||||
static constexpr int NUM_UNET_INPUT_CHANNELS = 9;
|
||||
|
||||
void StableDiffusionInpaintPipeline::PrepareMaskAndMaskedImage(
|
||||
const cv::Mat& image, const cv::Mat& mask_mat,
|
||||
const std::vector<int64_t>& shape, FDTensor* mask, FDTensor* mask_image) {
|
||||
vision::FDMat image_fdmat(image);
|
||||
vision::BGR2RGB::Run(&image_fdmat, vision::ProcLib::OPENCV);
|
||||
vision::Resize::Run(&image_fdmat, shape[1] * 8, shape[0] * 8, -1.0f, -1.0f,
|
||||
cv::INTER_NEAREST, false, vision::ProcLib::OPENCV);
|
||||
image_fdmat.ShareWithTensor(mask_image);
|
||||
|
||||
vision::FDMat mask_fdmat(mask_mat);
|
||||
vision::BGR2GRAY::Run(&mask_fdmat, vision::ProcLib::OPENCV);
|
||||
vision::Resize::Run(&mask_fdmat, shape[1] * 8, shape[0] * 8, -1.0f, -1.0f,
|
||||
cv::INTER_NEAREST, false, vision::ProcLib::OPENCV);
|
||||
FDTensor image_mask;
|
||||
mask_fdmat.ShareWithTensor(&image_mask);
|
||||
function::Cast(image_mask, &image_mask, FDDataType::FP32);
|
||||
std::vector<float> float_mask(image_mask.Numel(), 0);
|
||||
float* image_mask_ptr = reinterpret_cast<float*>(image_mask.Data());
|
||||
for (int i = 0; i < image_mask.Numel(); ++i) {
|
||||
if (image_mask_ptr[i] < 127.5) {
|
||||
float_mask[i] = 1;
|
||||
}
|
||||
}
|
||||
image_mask.SetExternalData({1, 1, shape[1] * 8, shape[0] * 8},
|
||||
FDDataType::FP32, float_mask.data());
|
||||
|
||||
// Set mask_image
|
||||
mask_image->ExpandDim();
|
||||
function::Transpose(*mask_image, mask_image, {0, 3, 1, 2});
|
||||
function::Cast(*mask_image, mask_image, FDDataType::FP32);
|
||||
*mask_image = *mask_image / 127.5f - 1.0f;
|
||||
*mask_image = *mask_image * image_mask;
|
||||
|
||||
// Set mask
|
||||
vision::FDMat mask_fdmat_t(mask_mat);
|
||||
vision::BGR2GRAY::Run(&mask_fdmat_t, vision::ProcLib::OPENCV);
|
||||
vision::Resize::Run(&mask_fdmat_t, shape[1], shape[0], -1.0f, -1.0f,
|
||||
cv::INTER_NEAREST, false, vision::ProcLib::OPENCV);
|
||||
mask_fdmat_t.ShareWithTensor(mask);
|
||||
function::Cast(*mask, mask, FDDataType::FP32);
|
||||
*mask = *mask / 255.0f;
|
||||
mask->ExpandDim();
|
||||
function::Transpose(*mask, mask, {0, 3, 1, 2});
|
||||
float* mask_data = reinterpret_cast<float*>(mask->Data());
|
||||
for (int i = 0; i < mask->Numel(); ++i) {
|
||||
if (mask_data[i] < 0.5) {
|
||||
mask_data[i] = 0;
|
||||
} else {
|
||||
mask_data[i] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StableDiffusionInpaintPipeline::StableDiffusionInpaintPipeline(
|
||||
std::unique_ptr<Runtime> vae_encoder, std::unique_ptr<Runtime> vae_decoder,
|
||||
std::unique_ptr<Runtime> text_encoder, std::unique_ptr<Runtime> unet,
|
||||
std::unique_ptr<Scheduler> scheduler,
|
||||
const paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer&
|
||||
tokenizer)
|
||||
: vae_encoder_(std::move(vae_encoder)),
|
||||
vae_decoder_(std::move(vae_decoder)),
|
||||
text_encoder_(std::move(text_encoder)), unet_(std::move(unet)),
|
||||
scheduler_(std::move(scheduler)), tokenizer_(tokenizer) {}
|
||||
|
||||
void StableDiffusionInpaintPipeline::Predict(
|
||||
const std::vector<std::string>& prompts, const cv::Mat& image,
|
||||
const cv::Mat& mask_image, std::vector<FDTensor>* output_images, int height,
|
||||
int width, int num_inference_steps, float guidance_scale,
|
||||
const std::vector<std::string>& negative_prompt, int num_images_per_prompt,
|
||||
float eta, uint32_t max_length, const FDTensor* latents, bool output_cv_mat,
|
||||
callback_ptr callback, int callback_steps) {
|
||||
int batch_size = prompts.size();
|
||||
FDASSERT(batch_size >= 1, "prompts should not be empty");
|
||||
FDASSERT(
|
||||
height % 8 == 0 && width % 8 == 0,
|
||||
"`height` and `width` have to be divisible by 8 but are {%d} and {%d}.",
|
||||
height, width);
|
||||
FDASSERT(callback_steps > 0,
|
||||
"`callback_steps` has to be a positive integer but is {%d}",
|
||||
callback_steps);
|
||||
|
||||
// Setting tokenizer attr
|
||||
if (max_length == 0) {
|
||||
tokenizer_.EnablePadMethod(fast_tokenizer::core::RIGHT,
|
||||
tokenizer_.GetPadTokenId(), 0,
|
||||
tokenizer_.GetPadToken(), nullptr, nullptr);
|
||||
tokenizer_.DisableTruncMethod();
|
||||
} else {
|
||||
tokenizer_.EnablePadMethod(fast_tokenizer::core::RIGHT,
|
||||
tokenizer_.GetPadTokenId(), 0,
|
||||
tokenizer_.GetPadToken(), &max_length, nullptr);
|
||||
tokenizer_.EnableTruncMethod(max_length, 0, fast_tokenizer::core::RIGHT,
|
||||
fast_tokenizer::core::LONGEST_FIRST);
|
||||
}
|
||||
std::vector<fast_tokenizer::core::Encoding> encodings;
|
||||
tokenizer_.EncodeBatchStrings(prompts, &encodings);
|
||||
|
||||
std::vector<int64_t> input_ids;
|
||||
for (auto& encoding : encodings) {
|
||||
auto curr_ids = encoding.GetIds();
|
||||
input_ids.insert(input_ids.end(), curr_ids.begin(), curr_ids.end());
|
||||
}
|
||||
encodings.clear();
|
||||
// Get text encoder output
|
||||
FDTensor text_intput_ids;
|
||||
std::vector<FDTensor> inputs(1);
|
||||
inputs[0].SetExternalData({batch_size, max_length}, FDDataType::INT64,
|
||||
input_ids.data());
|
||||
|
||||
TensorInfo text_info = text_encoder_->GetInputInfo(0);
|
||||
inputs[0].name = text_info.name;
|
||||
int output_size = text_encoder_->GetOutputInfos().size();
|
||||
std::vector<FDTensor> outputs(output_size);
|
||||
text_encoder_->Infer(inputs, &outputs);
|
||||
|
||||
FDTensor text_embeddings;
|
||||
function::Tile(outputs[0], {num_images_per_prompt, 1, 1}, &text_embeddings);
|
||||
|
||||
// here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
// of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
// corresponds to doing no classifier free guidance.
|
||||
bool do_classifier_free_guidance = guidance_scale > 1.0;
|
||||
if (do_classifier_free_guidance) {
|
||||
std::vector<std::string> uncond_tokens;
|
||||
if (negative_prompt.size() == 0) {
|
||||
uncond_tokens = {""};
|
||||
} else if (negative_prompt.size() != batch_size) {
|
||||
FDASSERT(false,
|
||||
"negative_prompt has batch size %d, but prompt has batch size "
|
||||
"%d. Please make sure that passed `negative_prompt` matches the "
|
||||
"batch size of `prompt`.",
|
||||
prompts.size(), negative_prompt.size());
|
||||
} else {
|
||||
uncond_tokens = negative_prompt;
|
||||
}
|
||||
tokenizer_.EncodeBatchStrings(uncond_tokens, &encodings);
|
||||
input_ids.clear();
|
||||
for (auto& encoding : encodings) {
|
||||
auto curr_ids = encoding.GetIds();
|
||||
input_ids.insert(input_ids.end(), curr_ids.begin(), curr_ids.end());
|
||||
}
|
||||
inputs[0].SetExternalData({batch_size, max_length}, FDDataType::INT64,
|
||||
input_ids.data());
|
||||
text_encoder_->Infer(inputs, &outputs);
|
||||
FDTensor uncond_embeddings;
|
||||
function::Tile(outputs[0], {num_images_per_prompt, 1, 1},
|
||||
&uncond_embeddings);
|
||||
function::Concat({uncond_embeddings, text_embeddings}, &text_embeddings);
|
||||
}
|
||||
std::vector<int64_t> latents_shape = {batch_size * num_images_per_prompt,
|
||||
NUM_LATENT_CHANNELS, height / 8,
|
||||
width / 8};
|
||||
auto latents_dtype = text_embeddings.Dtype();
|
||||
FDTensor actual_latents;
|
||||
if (latents == nullptr) {
|
||||
function::GaussianRandom(latents_shape, &actual_latents, latents_dtype);
|
||||
} else {
|
||||
bool result = std::equal(latents_shape.begin(), latents_shape.end(),
|
||||
latents->Shape().begin());
|
||||
FDASSERT(result, "Unexpected latents shape, got %s, expected %s",
|
||||
Str(latents_shape).c_str(), Str(latents->Shape()).c_str());
|
||||
actual_latents = *latents;
|
||||
}
|
||||
FDTensor mask_t, mask_image_t;
|
||||
PrepareMaskAndMaskedImage(image, mask_image, {height / 8, width / 8}, &mask_t,
|
||||
&mask_image_t);
|
||||
function::Cast(mask_t, &mask_t, actual_latents.Dtype());
|
||||
function::Cast(mask_image_t, &mask_image_t, actual_latents.Dtype());
|
||||
|
||||
// Get vae encoder output
|
||||
TensorInfo vae_encoder_info = vae_encoder_->GetInputInfo(0);
|
||||
mask_image_t.name = vae_encoder_info.name;
|
||||
outputs.resize(vae_encoder_->GetOutputInfos().size());
|
||||
inputs = {mask_image_t};
|
||||
vae_encoder_->Infer(inputs, &outputs);
|
||||
FDTensor masked_image_latents = 0.18215f * outputs[0];
|
||||
|
||||
std::vector<int64_t> mask_shape(mask_t.Shape().size(), 1);
|
||||
mask_shape[0] = batch_size * num_images_per_prompt;
|
||||
function::Tile(mask_t, mask_shape, &mask_t);
|
||||
|
||||
std::vector<int64_t> mask_image_shape(masked_image_latents.Shape().size(), 1);
|
||||
mask_image_shape[0] = batch_size * num_images_per_prompt;
|
||||
function::Tile(masked_image_latents, mask_image_shape, &masked_image_latents);
|
||||
|
||||
if (do_classifier_free_guidance) {
|
||||
function::Concat({mask_t, mask_t}, &mask_t);
|
||||
function::Concat({masked_image_latents, masked_image_latents},
|
||||
&masked_image_latents);
|
||||
}
|
||||
int num_channels_mask = mask_t.Shape()[1];
|
||||
int num_channels_masked_image = masked_image_latents.Shape()[1];
|
||||
FDASSERT(
|
||||
NUM_LATENT_CHANNELS + num_channels_mask + num_channels_masked_image ==
|
||||
NUM_UNET_INPUT_CHANNELS,
|
||||
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
||||
" %d but received `num_channels_latents`: %d + `num_channels_mask`: %d "
|
||||
"+ `num_channels_masked_image`: %d"
|
||||
" = %d. Please verify the config of `pipeline.unet` or your `mask_image` "
|
||||
"or `image` input.",
|
||||
NUM_UNET_INPUT_CHANNELS, NUM_LATENT_CHANNELS, num_channels_mask,
|
||||
num_channels_masked_image,
|
||||
NUM_LATENT_CHANNELS + num_channels_mask + num_channels_masked_image);
|
||||
|
||||
// set timesteps
|
||||
scheduler_->SetTimesteps(num_inference_steps);
|
||||
|
||||
// scale the initial noise by the standard deviation required by the scheduler
|
||||
actual_latents = actual_latents * scheduler_->InitNoiseSigma();
|
||||
|
||||
auto timestep = scheduler_->GetTimesteps();
|
||||
int64_t* timestep_data = reinterpret_cast<int64_t*>(timestep.Data());
|
||||
outputs.resize(unet_->GetOutputInfos().size());
|
||||
inputs.resize(unet_->GetInputInfos().size());
|
||||
inputs[2] = std::move(text_embeddings);
|
||||
auto unet_infos = unet_->GetInputInfos();
|
||||
for (int i = 0; i < timestep.Numel(); ++i) {
|
||||
FDTensor t;
|
||||
function::Slice(timestep, {0}, {i}, &t);
|
||||
inputs[1] = t;
|
||||
// expand the latents if we are doing classifier free guidance
|
||||
FDTensor latent_model_input;
|
||||
if (do_classifier_free_guidance) {
|
||||
function::Concat({actual_latents, actual_latents}, &latent_model_input);
|
||||
} else {
|
||||
latent_model_input = actual_latents;
|
||||
}
|
||||
// concat latents, mask, masked_image_latnets in the channel dimension
|
||||
function::Concat({latent_model_input, mask_t, masked_image_latents},
|
||||
&latent_model_input, 1);
|
||||
scheduler_->ScaleModelInput(latent_model_input, &latent_model_input, {t});
|
||||
inputs[0] = std::move(latent_model_input);
|
||||
// predict the noise residual
|
||||
for (int i = 0; i < unet_infos.size(); ++i) {
|
||||
inputs[i].name = unet_infos[i].name;
|
||||
}
|
||||
unet_->Infer(inputs, &outputs);
|
||||
FDTensor noise_pred = std::move(outputs[0]);
|
||||
// perform guidance
|
||||
if (do_classifier_free_guidance) {
|
||||
std::vector<FDTensor> noise_preds;
|
||||
int dim0 = noise_pred.Shape()[0];
|
||||
function::Split(noise_pred, {dim0 - dim0 / 2, dim0 / 2}, &noise_preds);
|
||||
noise_pred =
|
||||
noise_preds[0] + guidance_scale * (noise_preds[1] - noise_preds[0]);
|
||||
}
|
||||
|
||||
// compute the previous noisy sample x_t -> x_t-1
|
||||
int64_t time = reinterpret_cast<int64_t*>(t.Data())[0];
|
||||
scheduler_->Step(noise_pred, time, actual_latents, &actual_latents);
|
||||
|
||||
// call the callback, if provided
|
||||
if (callback != nullptr && i % callback_steps == 0) {
|
||||
callback(i, time, &actual_latents);
|
||||
}
|
||||
}
|
||||
actual_latents = (1.0f / 0.18215f) * actual_latents;
|
||||
|
||||
// Get vae decoder output
|
||||
int actual_latents_bs = actual_latents.Shape()[0];
|
||||
TensorInfo vae_decoder_info = vae_decoder_->GetInputInfo(0);
|
||||
inputs.resize(1);
|
||||
outputs.resize(vae_decoder_->GetOutputInfos().size());
|
||||
std::vector<FDTensor> decoder_reuslt;
|
||||
for (int i = 0; i < actual_latents_bs; ++i) {
|
||||
function::Slice(actual_latents, {0}, {i}, {i + 1}, &inputs[0]);
|
||||
inputs[0].name = vae_decoder_info.name;
|
||||
vae_decoder_->Infer(inputs, &outputs);
|
||||
decoder_reuslt.emplace_back(std::move(outputs[0]));
|
||||
}
|
||||
FDTensor output_image;
|
||||
function::Concat(decoder_reuslt, &output_image);
|
||||
|
||||
function::Clip(output_image / 2.0f + 0.5f, 0, 1, &output_image);
|
||||
function::Transpose(output_image, &output_image, {0, 2, 3, 1});
|
||||
|
||||
if (output_cv_mat) {
|
||||
output_image = output_image * 255.0f;
|
||||
function::Round(output_image, &output_image);
|
||||
function::Cast(output_image, &output_image, FDDataType::UINT8);
|
||||
}
|
||||
int output_batch_size = output_image.Shape()[0];
|
||||
output_images->resize(output_batch_size);
|
||||
for (int i = 0; i < output_batch_size; ++i) {
|
||||
function::Slice(output_image, {0}, {i}, &(*output_images)[i]);
|
||||
vision::FDMat mask_fdmat_t = vision::FDMat::Create((*output_images)[i]);
|
||||
vision::RGB2BGR::Run(&mask_fdmat_t, vision::ProcLib::OPENCV);
|
||||
mask_fdmat_t.CopyToTensor(&(*output_images)[i]);
|
||||
FDTensor sum;
|
||||
function::Sum((*output_images)[i], &sum, {}, false, true);
|
||||
FDINFO << "sum = " << ((float*)sum.Data())[0] << std::endl;
|
||||
}
|
||||
}
|
||||
} // namespace fastdeploy
|
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "./scheduler.h"
|
||||
#include "fast_tokenizer/tokenizers/clip_fast_tokenizer.h"
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
#include "fastdeploy/runtime.h"
|
||||
#include "opencv2/core/core.hpp"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
class StableDiffusionInpaintPipeline {
|
||||
public:
|
||||
typedef void (*callback_ptr)(int, int, FDTensor*);
|
||||
|
||||
StableDiffusionInpaintPipeline(
|
||||
std::unique_ptr<Runtime> vae_encoder,
|
||||
std::unique_ptr<Runtime> vae_decoder,
|
||||
std::unique_ptr<Runtime> text_encoder, std::unique_ptr<Runtime> unet,
|
||||
std::unique_ptr<Scheduler> scheduler,
|
||||
const paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer&
|
||||
tokenizer);
|
||||
void Predict(const std::vector<std::string>& prompts, const cv::Mat& image,
|
||||
const cv::Mat& mask_image, std::vector<FDTensor>* output_images,
|
||||
int height = 512, int width = 512, int num_inference_steps = 50,
|
||||
float guidance_scale = 7.5,
|
||||
const std::vector<std::string>& negative_prompt = {},
|
||||
int num_images_per_prompt = 1, float eta = 0.0,
|
||||
uint32_t max_length = 77, const FDTensor* latents = nullptr,
|
||||
bool output_cv_mat = true, callback_ptr callback = nullptr,
|
||||
int callback_steps = 1);
|
||||
|
||||
private:
|
||||
void PrepareMaskAndMaskedImage(const cv::Mat& image, const cv::Mat& mask_mat,
|
||||
const std::vector<int64_t>& shape,
|
||||
FDTensor* mask, FDTensor* mask_image);
|
||||
std::unique_ptr<Runtime> vae_encoder_;
|
||||
std::unique_ptr<Runtime> vae_decoder_;
|
||||
std::unique_ptr<Runtime> text_encoder_;
|
||||
std::unique_ptr<Runtime> unet_;
|
||||
std::unique_ptr<Scheduler> scheduler_;
|
||||
paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer tokenizer_;
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
@@ -19,13 +19,16 @@
|
||||
namespace fastdeploy {
|
||||
|
||||
class Scheduler {
|
||||
public:
|
||||
virtual void SetTimesteps(int num_inference_steps) = 0;
|
||||
virtual FDTensor GetTimesteps() = 0;
|
||||
virtual void Step(const FDTensor& model_output, int timestep,
|
||||
const FDTensor& sample, FDTensor* prev_sample) = 0;
|
||||
virtual void ScaleModelInput(const FDTensor& sample, FDTensor* out,
|
||||
const std::vector<FDTensor>& timesteps = {}) = 0;
|
||||
virtual void AddNoise(const FDTensor& original_samples, const FDTensor& noise,
|
||||
const FDTensor& timesteps, FDTensor* out) = 0;
|
||||
virtual float InitNoiseSigma() = 0;
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -88,11 +88,13 @@ template <typename T>
|
||||
void ConcatKernel(const std::vector<FDTensor>& input, FDTensor* output,
|
||||
int axis) {
|
||||
auto output_shape = ComputeAndCheckConcatOutputShape(input, axis);
|
||||
output->Resize(output_shape, TypeToDataType<T>::dtype, output->name,
|
||||
FDTensor output_tmp;
|
||||
output_tmp.Resize(output_shape, TypeToDataType<T>::dtype, output->name,
|
||||
input[0].device);
|
||||
|
||||
ConcatFunctor<T> functor;
|
||||
functor(input, axis, output);
|
||||
functor(input, axis, &output_tmp);
|
||||
*output = std::move(output_tmp);
|
||||
}
|
||||
|
||||
void Concat(const std::vector<FDTensor>& x, FDTensor* out, int axis) {
|
||||
|
@@ -21,6 +21,7 @@
|
||||
#include "fastdeploy/function/elementwise.h"
|
||||
#include "fastdeploy/function/full.h"
|
||||
#include "fastdeploy/function/gather_scatter_along_axis.h"
|
||||
#include "fastdeploy/function/gaussian_random.h"
|
||||
#include "fastdeploy/function/isfinite.h"
|
||||
#include "fastdeploy/function/linspace.h"
|
||||
#include "fastdeploy/function/math.h"
|
||||
|
46
fastdeploy/function/gaussian_random.cc
Normal file
46
fastdeploy/function/gaussian_random.cc
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/function/gaussian_random.h"
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace function {
|
||||
|
||||
template <typename T>
|
||||
void GaussianRandomKernel(const std::vector<int64_t>& shape, float mean,
|
||||
float std, int seed, FDTensor* out) {
|
||||
std::normal_distribution<T> dist(mean, std);
|
||||
|
||||
out->Allocate(shape, TypeToDataType<T>::dtype);
|
||||
int64_t size = out->Numel();
|
||||
T* data = reinterpret_cast<T*>(out->Data());
|
||||
std::mt19937_64 engine;
|
||||
engine.seed(seed);
|
||||
for (int64_t i = 0; i < size; ++i) {
|
||||
data[i] = dist(engine);
|
||||
}
|
||||
}
|
||||
|
||||
void GaussianRandom(const std::vector<int64_t>& shape, FDTensor* out,
|
||||
FDDataType dtype, float mean, float std, int seed) {
|
||||
FD_VISIT_FLOAT_TYPES(dtype, "GaussianRandomKernel", [&]() {
|
||||
GaussianRandomKernel<data_t>(shape, mean, std, seed, out);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace function
|
||||
} // namespace fastdeploy
|
36
fastdeploy/function/gaussian_random.h
Normal file
36
fastdeploy/function/gaussian_random.h
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace function {
|
||||
|
||||
/** Output is obtained by gathering entries of axis of x indexed by index and
|
||||
* concatenate them together.
|
||||
@param shape The output tensor shape.
|
||||
@param out the output tensor.
|
||||
@param mean mean value of gaussian random
|
||||
@param std standard value of gaussian random
|
||||
@param seed The seed of random generator.
|
||||
@param dtype The data type of the output Tensor.
|
||||
*/
|
||||
void GaussianRandom(const std::vector<int64_t>& shape, FDTensor* out,
|
||||
FDDataType dtype = FDDataType::FP32, float mean = 0.0f,
|
||||
float std = 1.0f, int seed = 0);
|
||||
|
||||
} // namespace function
|
||||
} // namespace fastdeploy
|
@@ -49,6 +49,7 @@ void TileFunctor(const FDTensor& x,
|
||||
return;
|
||||
}
|
||||
|
||||
FDTensor out_tmp;
|
||||
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
|
||||
for (size_t i = 0; i < repeat_times.size(); ++i) {
|
||||
bcast_dims[i] = repeat_times[i];
|
||||
@@ -59,12 +60,14 @@ void TileFunctor(const FDTensor& x,
|
||||
out_shape[i] *= repeat_times[i];
|
||||
}
|
||||
|
||||
out->Allocate(out_shape, x.Dtype());
|
||||
out_tmp.Allocate(out_shape, x.Dtype());
|
||||
auto eigen_x = EigenTensor<T, Rank>::From(x, x_shape);
|
||||
auto eigen_out = EigenTensor<T, Rank>::From(*out, out_shape);
|
||||
auto eigen_out = EigenTensor<T, Rank>::From(out_tmp, out_shape);
|
||||
|
||||
const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
|
||||
eigen_out.device(dev) = eigen_x.broadcast(bcast_dims);
|
||||
|
||||
*out = std::move(out_tmp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@@ -66,7 +66,8 @@ class FASTDEPLOY_DECL FDLogger {
|
||||
if (!verbose_ && line_ != "") {
|
||||
std::cout << line_ << std::endl;
|
||||
#ifdef __ANDROID__
|
||||
__android_log_print(ANDROID_LOG_INFO, prefix_.c_str(), "%s", line_.c_str());
|
||||
__android_log_print(ANDROID_LOG_INFO, prefix_.c_str(), "%s",
|
||||
line_.c_str());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -122,6 +123,8 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
|
||||
[&] { \
|
||||
const auto& __dtype__ = TYPE; \
|
||||
switch (__dtype__) { \
|
||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::UINT8, uint8_t, \
|
||||
__VA_ARGS__) \
|
||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::BOOL, bool, \
|
||||
__VA_ARGS__) \
|
||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \
|
||||
|
Reference in New Issue
Block a user