// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

// Gathers, for each sequence in the batch, one hidden state (the last prefill
// token for encoder sequences, the single current token for decoder sequences)
// from the packed [token_num, dim_embed] input into a dense [bsz, dim_embed]
// output.
template <typename T, int VecSize>
__global__ void RebuildPaddingKernel(T *output_data,
                                     const T *input_data,
                                     const int *cum_offsets,
                                     const int *seq_len_this_time,
                                     const int *seq_len_decoder,
                                     const int *seq_len_encoder,
                                     const int max_input_length,
                                     const int dim_embed,
                                     const int elem_nums) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;
  const int global_idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = global_idx * VecSize; i < elem_nums;
       i += gridDim.x * blockDim.x * VecSize) {
    const int bi = i / dim_embed;
    const int bias_idx = i % dim_embed;
    int seq_id = 0;

    if (seq_len_this_time[bi] == 0) continue;
    if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
    // if encoder, get last token; just decoder, get first token.
    if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
    const int ori_token_idx = bi * max_input_length - cum_offsets[bi] + seq_id;
    const int src_offset = ori_token_idx * dim_embed + bias_idx;
    Load<T, VecSize>(&input_data[src_offset], &src_vec);
    Store<T, VecSize>(src_vec, &output_data[i]);
  }
}

// Variant used when output_padding_offset is provided (speculative decoding):
// the output keeps several tokens per sequence, and output_padding_offset maps
// each output token back to its position in the padded layout.
template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(T *output_data,
                                           const T *input_data,
                                           const int *cum_offset,
                                           const int *seq_len_this_time,
                                           const int *seq_len_decoder,
                                           const int *seq_len_encoder,
                                           const int *output_padding_offset,
                                           const int max_input_length,
                                           const int dim_embed,
                                           const int64_t output_elem_nums) {
  AlignedVector<T, VecSize> src_vec;
  const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = global_idx * VecSize; i < output_elem_nums;
       i += gridDim.x * blockDim.x * VecSize) {
    const int out_token_id = i / dim_embed;
    const int ori_token_id =
        out_token_id + output_padding_offset[out_token_id];
    const int bi = ori_token_id / max_input_length;
    int seq_id = 0;

    if (seq_len_this_time[bi] == 0) continue;
    if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
    // if encoder, get last token; just decoder, get first token.
    if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;

    const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
    const int bias_idx = i % dim_embed;

    Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
                     &src_vec);
    Store<T, VecSize>(src_vec, &output_data[i]);
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> rebuild_padding(
    const paddle::Tensor &tmp_out,      // [token_num, dim_embed]
    const paddle::Tensor &cum_offsets,  // [bsz, 1]
    const paddle::Tensor &seq_len_this_time,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::optional<paddle::Tensor> &output_padding_offset,
    int max_input_length) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext *>(
      paddle::experimental::DeviceContextPool::Instance().Get(
          tmp_out.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = tmp_out.stream();
#endif
  std::vector<int64_t> tmp_out_shape = tmp_out.shape();
  const int token_num = tmp_out_shape[0];
  const int dim_embed = tmp_out_shape[1];
  const int bsz = cum_offsets.shape()[0];

  paddle::Tensor out;
  if (output_padding_offset) {
    // Keep one output row per decoded step: drop all but the last token of
    // every prefill (encoder) sequence.
    int need_delete_token_num = 0;
    auto seq_lens_encoder_cpu =
        seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
    for (int i = 0; i < bsz; ++i) {
      if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
        need_delete_token_num += seq_lens_encoder_cpu.data<int>()[i] - 1;
      }
    }
    out = paddle::full({token_num - need_delete_token_num, dim_embed},
                       0,
                       D,
                       tmp_out.place());
  } else {
    out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
  }

  constexpr int PackSize = VEC_16B / sizeof(DataType_);
  int elem_nums = out.numel();
  int pack_num = elem_nums / PackSize;
  const int blocksize = 128;
  const int grid_size = (pack_num + blocksize - 1) / blocksize;
  if (output_padding_offset) {
    RebuildAppendPaddingKernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, cu_stream>>>(
            reinterpret_cast<DataType_ *>(out.data<data_t>()),
            reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
            cum_offsets.data<int>(),
            seq_len_this_time.data<int>(),
            seq_lens_decoder.data<int>(),
            seq_lens_encoder.data<int>(),
            output_padding_offset.get_ptr()->data<int>(),
            max_input_length,
            dim_embed,
            elem_nums);
  } else {
    RebuildPaddingKernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, cu_stream>>>(
            reinterpret_cast<DataType_ *>(out.data<data_t>()),
            reinterpret_cast<DataType_ *>(
                const_cast<data_t *>(tmp_out.data<data_t>())),
            cum_offsets.data<int>(),
            seq_len_this_time.data<int>(),
            seq_lens_decoder.data<int>(),
            seq_lens_encoder.data<int>(),
            max_input_length,
            dim_embed,
            elem_nums);
  }
  return {out};
}

paddle::Tensor RebuildPaddingFunc(
    const paddle::Tensor &tmp_out,      // [token_num, dim_embed]
    const paddle::Tensor &cum_offsets,  // [bsz, 1]
    const paddle::Tensor &seq_len_this_time,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::optional<paddle::Tensor> &output_padding_offset,
    int max_input_length) {
  switch (tmp_out.type()) {
    case paddle::DataType::BFLOAT16: {
      return rebuild_padding<paddle::DataType::BFLOAT16>(
          tmp_out,
          cum_offsets,
          seq_len_this_time,
          seq_lens_decoder,
          seq_lens_encoder,
          output_padding_offset,
          max_input_length)[0];
    }
    case paddle::DataType::FLOAT16: {
      return rebuild_padding<paddle::DataType::FLOAT16>(
          tmp_out,
          cum_offsets,
          seq_len_this_time,
          seq_lens_decoder,
          seq_lens_encoder,
          output_padding_offset,
          max_input_length)[0];
    }
    case paddle::DataType::FLOAT32: {
      return rebuild_padding<paddle::DataType::FLOAT32>(
          tmp_out,
          cum_offsets,
          seq_len_this_time,
          seq_lens_decoder,
          seq_lens_encoder,
          output_padding_offset,
          max_input_length)[0];
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
"); break; } } } std::vector RebuildPadding( const paddle::Tensor &tmp_out, // [token_num, dim_embed] const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &seq_len_this_time, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, int max_input_length) { return {RebuildPaddingFunc(tmp_out, cum_offsets, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, output_padding_offset, max_input_length)}; } std::vector> RebuildPaddingInferShape( const std::vector &tmp_out_shape, const std::vector &cum_offsets_shape, const std::vector &seq_len_this_time_shape, const std::vector &seq_lens_decoder_shape, const std::vector &seq_lens_encoder_shape, const paddle::optional> &output_padding_offset_shape) { int64_t dim_embed = tmp_out_shape[1]; // whether speculative decoding if (output_padding_offset_shape) { return {{-1, dim_embed}}; } else { int64_t bsz = cum_offsets_shape[0]; return {{bsz, dim_embed}}; } } std::vector RebuildPaddingInferDtype( const paddle::DataType &tmp_out_dtype, const paddle::DataType &cum_offsets_dtype, const paddle::DataType &seq_len_this_time_dtype, const paddle::DataType &seq_lens_decoder_dtype, const paddle::DataType &seq_lens_encoder_dtype, const paddle::optional &output_padding_offset_dtype) { return {tmp_out_dtype}; } PD_BUILD_STATIC_OP(rebuild_padding) .Inputs({"tmp_out", "cum_offsets", "seq_len_this_time", "seq_lens_decoder", "seq_lens_encoder", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) .Attrs({"max_input_length: int"}) .SetKernelFn(PD_KERNEL(RebuildPadding)) .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));