// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__global__ void PrefixSumKernel(int64_t *ids_remove_padding,
                                int *batch_id_per_token,
                                int *cu_seqlens_q,
                                int *cu_seqlens_k,
                                const int64_t *input_data,
                                const int *seq_lens,
                                const int max_seq_len) {
  const int bi = blockIdx.x;
  const int tid = threadIdx.x;
#ifdef PADDLE_WITH_COREX
  const int warp_id = threadIdx.x / 64;
  const int lane_id = threadIdx.x % 64;
#else
  const int warp_id = threadIdx.x / 32;
  const int lane_id = threadIdx.x % 32;
#endif
  int cum_seq_len = 0;
  // Compute the sum of seq_lens[0..bi]: each lane accumulates a strided
  // slice, and the warp scan below combines the partial sums.
  for (int i = lane_id; i < bi + 1; i += WARP_SIZE) {
    cum_seq_len += seq_lens[i];
  }
  // Hillis-Steele inclusive scan across the warp via shuffle-up.
  for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
    const int tmp = __shfl_up_sync(0xffffffff, cum_seq_len, offset);
    if (lane_id >= offset) cum_seq_len += tmp;
  }
  // After the scan, the last lane holds the full sum; broadcast it.
  cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, WARP_SIZE - 1);
  if (tid == 0) {
    cu_seqlens_q[bi + 1] = cum_seq_len;
    cu_seqlens_k[bi + 1] = cum_seq_len;
  }
  if (bi == 0 && tid == 0) {
    cu_seqlens_q[0] = 0;
    cu_seqlens_k[0] = 0;
  }
  // Copy this sequence's tokens into the packed (padding-free) layout and
  // record which batch each packed token came from.
  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
    const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i;
    const int src_seq_id = bi * max_seq_len + i;
    ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
    batch_id_per_token[tgt_seq_id] = bi;
  }
}
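// Worked example (illustrative only): with seq_lens = {3, 1, 2} and
// max_seq_len = 4, the kernel produces
//   cu_seqlens_q = cu_seqlens_k = {0, 3, 4, 6}
//   batch_id_per_token          = {0, 0, 0, 1, 2, 2}
// and ids_remove_padding holds the 6 valid tokens of input_data packed
// contiguously, with the per-row padding stripped.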
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
                                             const paddle::Tensor &seq_len,
                                             const int64_t cpu_token_num) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext *>(
      paddle::experimental::DeviceContextPool::Instance().Get(
          input_ids.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = input_ids.stream();
#endif
  std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int max_seq_len = input_ids_shape[1];
  const int token_num_data = cpu_token_num;
  auto x_remove_padding = paddle::empty(
      {token_num_data}, paddle::DataType::INT64, input_ids.place());
  auto batch_id_per_token = paddle::empty(
      {token_num_data}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  // One block per sequence; round the block size up to a whole number of
  // warps, capped at 128 threads.
#ifdef PADDLE_WITH_COREX
  int blockSize =
      std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#else
  int blockSize =
      min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#endif
  PrefixSumKernel<<<bsz, blockSize, 0, cu_stream>>>(
      x_remove_padding.data<int64_t>(),
      batch_id_per_token.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      input_ids.data<int64_t>(),
      seq_len.data<int>(),
      max_seq_len);
  return {x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k};
}

std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
    const std::vector<int64_t> &input_ids_shape,
    const std::vector<int64_t> &seq_len_shape) {
  int64_t bsz = seq_len_shape[0];
  // The packed outputs depend on the runtime token count, so their first
  // dimension is dynamic.
  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
    const paddle::DataType &input_ids_dtype,
    const paddle::DataType &seq_len_dtype) {
  return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
}

PD_BUILD_STATIC_OP(get_padding_offset)
    .Inputs({"input_ids", "seq_len"})
    .Outputs({"x_remove_padding",
              "batch_id_per_token",
              "cu_seqlens_q",
              "cu_seqlens_k"})
    .Attrs({"cpu_token_num: int64_t"})
    .SetKernelFn(PD_KERNEL(GetPaddingOffset))
    .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
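// A minimal host-side sketch of the same packing, assuming plain C-style
// buffers with hypothetical names; useful as an oracle in unit tests:
//
//   cu_seqlens_q[0] = cu_seqlens_k[0] = 0;
//   for (int b = 0, offset = 0; b < bsz; ++b) {
//     for (int i = 0; i < seq_lens[b]; ++i) {
//       ids_remove_padding[offset + i] = input_ids[b * max_seq_len + i];
//       batch_id_per_token[offset + i] = b;
//     }
//     offset += seq_lens[b];
//     cu_seqlens_q[b + 1] = cu_seqlens_k[b + 1] = offset;
//   }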