// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
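
// Routes rows of a fused [token_num, hidden_size] activation between
// per-modality text and image buffers, keyed by token_type_ids.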
#include <algorithm>  // std::max / std::min used in GetGridSize

#include "helper.h"
// Computes the launch grid for a grid-stride loop kernel: enough blocks to
// cover all `n` work items, capped at `num_waves` full waves of resident
// blocks so the grid does not oversubscribe the device, and at least one
// block for tiny workloads.
inline cudaError_t GetGridSize(int64_t n, int block_size, int num_waves, int* num_blocks) {
    int dev;
    {
        cudaError_t err = cudaGetDevice(&dev);
        if (err != cudaSuccess) {
            return err;
        }
    }
    int sm_count;
    {
        cudaError_t err =
            cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
        if (err != cudaSuccess) {
            return err;
        }
    }
    int tpm;
    {
        cudaError_t err = cudaDeviceGetAttribute(
            &tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
        if (err != cudaSuccess) {
            return err;
        }
    }
    // One "wave" = sm_count * tpm / block_size blocks resident at once.
    *num_blocks =
        std::max<int>(1,
                      std::min<int64_t>((n + block_size - 1) / block_size,
                                        sm_count * tpm / block_size * num_waves));
    return cudaSuccess;
}
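// For illustration (hypothetical device numbers, not measured): with 108 SMs,
// 2048 max threads per SM, block_size = 128, and num_waves = 16, the wave cap
// is 108 * 2048 / 128 * 16 = 27648 blocks; a workload of n = 1,000,000 packs
// therefore gets ceil(1000000 / 128) = 7813 blocks, while a tiny workload
// still gets at least one block.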
// Scatter: splits the fused [token_num, hidden_size] input into separate text
// and image buffers, routing each token by its token type (0 = text,
// otherwise image). Vectorized grid-stride loop; assumes hidden_size is a
// multiple of VecSize so no vector access straddles a row boundary.
template <typename T, int VecSize>
__global__ void text_image_scatter_kernel(
    T* input_ptr,
    T* text_gather_ptr,
    T* image_gather_ptr,
    int32_t* token_type_ids,
    int32_t* text_index,
    int32_t* image_index,
    const int64_t hidden_size,
    const int64_t total_element_num) {
    using T_Vec = AlignedVector<T, VecSize>;
    T_Vec input_ptr_vec;
    T_Vec text_images_vec;

    int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t step = blockDim.x * gridDim.x * VecSize;

    for (int64_t element_idx = global_thread_id * VecSize;
         element_idx < total_element_num;
         element_idx += step) {
        int64_t token_idx = element_idx / hidden_size;
        int64_t hidden_offset = element_idx % hidden_size;
        int32_t token_type_id = token_type_ids[token_idx];

        int64_t input_load_offset = token_idx * hidden_size + hidden_offset;

        Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll
        for (int vi = 0; vi < VecSize; ++vi) {
            text_images_vec[vi] = input_ptr_vec[vi];
        }

        if (token_type_id == 0) {
            // Text token: write to its compacted row in the text buffer.
            int64_t text_store_offset = text_index[token_idx] * hidden_size + hidden_offset;
            Store<T, VecSize>(text_images_vec, text_gather_ptr + text_store_offset);
        } else {
            // Image token: write to its compacted row in the image buffer.
            int64_t image_store_offset = image_index[token_idx] * hidden_size + hidden_offset;
            Store<T, VecSize>(text_images_vec, image_gather_ptr + image_store_offset);
        }
    }
}
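// Worked example (hypothetical values): token_type_ids = [0, 1, 0, 1, 1],
// text_index = [0, _, 1, _, _], image_index = [_, 0, _, 1, 2], where `_`
// entries are never read. Scatter copies input rows {0, 2} to text buffer
// rows {0, 1} and input rows {1, 3, 4} to image buffer rows {0, 1, 2}; the
// gather kernel below inverts this mapping.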
// Gather: the inverse of the scatter above. Reads each token's row from the
// text or image buffer (selected by its token type) and writes it back to the
// token's original position in the fused output. Same divisibility
// assumption on hidden_size as the scatter kernel.
template <typename T, int VecSize>
__global__ void text_image_gather_kernel(
    T* output_ptr,
    T* text_gather_ptr,
    T* image_gather_ptr,
    int32_t* token_type_ids,
    int32_t* text_index,
    int32_t* image_index,
    const int64_t hidden_size,
    const int64_t total_element_num) {
    using T_Vec = AlignedVector<T, VecSize>;
    T_Vec output_ptr_vec;
    T_Vec text_images_vec;

    int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t step = blockDim.x * gridDim.x * VecSize;

    for (int64_t element_idx = global_thread_id * VecSize;
         element_idx < total_element_num;
         element_idx += step) {
        int64_t token_idx = element_idx / hidden_size;
        int64_t hidden_offset = element_idx % hidden_size;
        int32_t token_type_id = token_type_ids[token_idx];

        if (token_type_id == 0) {
            int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
            Load<T, VecSize>(text_gather_ptr + text_load_offset, &text_images_vec);
        } else {
            int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
            Load<T, VecSize>(image_gather_ptr + image_load_offset, &text_images_vec);
        }

#pragma unroll
        for (int vi = 0; vi < VecSize; ++vi) {
            output_ptr_vec[vi] = text_images_vec[vi];
        }

        int64_t output_store_offset = token_idx * hidden_size + hidden_offset;
        Store<T, VecSize>(output_ptr_vec, output_ptr + output_store_offset);
    }
}
template <paddle::DataType D>
void LaunchTextImageGatherScatter(
    paddle::Tensor& input,
    paddle::Tensor& text_input,
    paddle::Tensor& image_input,
    paddle::Tensor& token_type_ids,
    paddle::Tensor& text_index,
    paddle::Tensor& image_index,
    const bool is_scatter) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    auto stream = input.stream();
    const auto& in_dims = input.dims();
    const int64_t token_num = in_dims[0];
    const int64_t hidden_size = in_dims[1];

    // Vectorize to 16-byte accesses: 8 elements per vector for 16-bit types.
    constexpr int VecSize = 16 / sizeof(data_t);
    const int64_t tot_element_num = token_num * hidden_size;
    const int64_t tot_pack_num = (tot_element_num + VecSize - 1) / VecSize;

    const int block_size = 128;
    constexpr int32_t kNumWaves = 16;
    int grid_size_x = -1;

    PADDLE_ENFORCE_GPU_SUCCESS(
        GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
    dim3 grid_dim = dim3(grid_size_x, 1, 1);
    if (is_scatter) {
        text_image_scatter_kernel<DataType_, VecSize>
            <<<grid_dim, block_size, 0, stream>>>(
                reinterpret_cast<DataType_*>(input.data<data_t>()),
                reinterpret_cast<DataType_*>(text_input.data<data_t>()),
                reinterpret_cast<DataType_*>(image_input.data<data_t>()),
                token_type_ids.data<int32_t>(),
                text_index.data<int32_t>(),
                image_index.data<int32_t>(),
                hidden_size,
                tot_element_num);
    } else {
        text_image_gather_kernel<DataType_, VecSize>
            <<<grid_dim, block_size, 0, stream>>>(
                reinterpret_cast<DataType_*>(input.data<data_t>()),
                reinterpret_cast<DataType_*>(text_input.data<data_t>()),
                reinterpret_cast<DataType_*>(image_input.data<data_t>()),
                token_type_ids.data<int32_t>(),
                text_index.data<int32_t>(),
                image_index.data<int32_t>(),
                hidden_size,
                tot_element_num);
    }
}
void TextImageGatherScatter(
    paddle::Tensor& input,
    paddle::Tensor& text_input,
    paddle::Tensor& image_input,
    paddle::Tensor& token_type_ids,
    paddle::Tensor& text_index,
    paddle::Tensor& image_index,
    const bool is_scatter) {
    switch (input.type()) {
        case paddle::DataType::BFLOAT16: {
            return LaunchTextImageGatherScatter<paddle::DataType::BFLOAT16>(
                input, text_input, image_input, token_type_ids, text_index,
                image_index, is_scatter);
        }
        default: {
            PD_THROW("Unsupported data type: only BFLOAT16 is supported.");
        }
    }
}
PD_BUILD_STATIC_OP(text_image_gather_scatter)
    .Inputs({"input",
             "text_input",
             "image_input",
             "token_type_ids",
             "text_index",
             "image_index"})
    .Outputs({"text_input_out",
              "image_input_out",
              "text_index_out",
              "image_index_out"})
    .Attrs({"is_scatter:bool"})
    .SetInplaceMap({{"text_input", "text_input_out"},
                    {"image_input", "image_input_out"},
                    {"text_index", "text_index_out"},
                    {"image_index", "image_index_out"}})
    .SetKernelFn(PD_KERNEL(TextImageGatherScatter));
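
// Illustrative shapes for the registered op (a sketch inferred from the
// kernels above, not authoritative documentation):
//   input:          [token_num, hidden_size]        bfloat16 fused buffer
//   text_input:     [text_token_num, hidden_size]   bfloat16, updated in place
//   image_input:    [image_token_num, hidden_size]  bfloat16, updated in place
//   token_type_ids: [token_num] int32 (0 = text, nonzero = image)
//   text_index / image_index: [token_num] int32 compacted row indices
// is_scatter = true splits `input` into the two modality buffers;
// is_scatter = false merges them back into `input`.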