diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 36b193d0c..979a94eb7 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -384,7 +384,6 @@ void GetBlockShapeAndSplitKVBlock(
     const int decoder_step_token_num);
 
 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
-                                             const paddle::Tensor& cum_offsets,
                                              const paddle::Tensor& token_num,
                                              const paddle::Tensor& seq_len);
 
diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu
index f36201389..646e0a159 100644
--- a/custom_ops/gpu_ops/get_padding_offset.cu
+++ b/custom_ops/gpu_ops/get_padding_offset.cu
@@ -12,127 +12,119 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/extension.h"
 #include "helper.h"
+#include "paddle/extension.h"
 
 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
 
-__global__ void RemovePadding(int64_t *output_data,
-                              const int64_t *input_data,
-                              const int *seq_lens,
-                              const int *cum_offsets,
-                              const int sequence_length) {
-    const int bi = blockIdx.x;
-    const int tid = threadIdx.x;
+__global__ void PrefixSumKernel(int64_t *ids_remove_padding,
+                                int *batch_id_per_token,
+                                int *cu_seqlens_q,
+                                int *cu_seqlens_k,
+                                const int64_t *input_data,
+                                const int *seq_lens,
+                                const int max_seq_len) {
+  const int bi = blockIdx.x;
+  const int tid = threadIdx.x;
+  const int warp_id = threadIdx.x / 32;
+  const int lane_id = threadIdx.x % 32;
 
-    for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
-        const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
-        const int src_seq_id = bi * sequence_length + i;
-        output_data[tgt_seq_id] = input_data[src_seq_id];
-    }
-}
+  int cum_seq_len = 0;
 
-__global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
-                                       int *cum_offsets_out,
-                                       int *cu_seqlens_q,
-                                       int *cu_seqlens_k,
-                                       const int *cum_offsets,
-                                       const int *seq_lens,
-                                       const int max_seq_len) {
-    // get padding offset of each batch
-    const int bi = blockIdx.x;
-    const int ti = threadIdx.x;
-    int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
-    for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
-#ifdef PADDLE_WITH_HIP
-        batch_id_per_token[bi * max_seq_len - cum_offset + i] = cum_offset;
-#else
-        batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
-#endif
-    }
-    if (ti == 0) {
-        cum_offsets_out[bi] = cum_offset;
-        int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
-        cu_seqlens_q[bi + 1] = cum_seq_len;
-        cu_seqlens_k[bi + 1] = cum_seq_len;
-    }
+  // compute sum of seq_lens[0,1,2,...,bi]
+  for (int i = lane_id; i < bi + 1; i += warpSize) {
+    cum_seq_len += seq_lens[i];
+  }
+
+  for (int offset = 1; offset < warpSize; offset <<= 1) {
+    const int tmp = __shfl_up_sync(0xffffffff, cum_seq_len, offset);
+    if (lane_id >= offset) cum_seq_len += tmp;
+  }
+
+  cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, warpSize - 1);
+
+  if (tid == 0) {
+    cu_seqlens_q[bi + 1] = cum_seq_len;
+    cu_seqlens_k[bi + 1] = cum_seq_len;
+  }
+
+  if (bi == 0 && tid == 0) {
+    cu_seqlens_q[0] = 0;
+    cu_seqlens_k[0] = 0;
+  }
+
+  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
+    const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i;
+    const int src_seq_id = bi * max_seq_len + i;
+    ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
+    batch_id_per_token[tgt_seq_id] = bi;
+  }
 }
 
 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
-                                             const paddle::Tensor &cum_offsets,
                                              const paddle::Tensor &token_num,
                                              const paddle::Tensor &seq_len) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-    auto dev_ctx = static_cast<const phi::CustomContext *>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
-    auto cu_stream = dev_ctx->stream();
+  auto dev_ctx = static_cast<const phi::CustomContext *>(
+      paddle::experimental::DeviceContextPool::Instance().Get(
+          input_ids.place()));
+  auto cu_stream = dev_ctx->stream();
 #else
-    auto cu_stream = input_ids.stream();
+  auto cu_stream = input_ids.stream();
 #endif
-    std::vector<int64_t> input_ids_shape = input_ids.shape();
-    const int bsz = seq_len.shape()[0];
-    const int seq_length = input_ids_shape[1];
-    auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
-    auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
+  std::vector<int64_t> input_ids_shape = input_ids.shape();
+  const int bsz = seq_len.shape()[0];
+  const int max_seq_len = input_ids_shape[1];
+  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
 
-    const int token_num_data = cpu_token_num.data<int64_t>()[0];
-    auto x_remove_padding = paddle::empty(
-        {token_num_data}, paddle::DataType::INT64, input_ids.place());
-    auto batch_id_per_token = paddle::empty(
-        {token_num_data}, paddle::DataType::INT32, input_ids.place());
-    auto cu_seqlens_q =
-        paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-    auto cu_seqlens_k =
-        paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
+  const int token_num_data = cpu_token_num.data<int64_t>()[0];
+  auto x_remove_padding = paddle::empty(
+      {token_num_data}, paddle::DataType::INT64, input_ids.place());
+  auto batch_id_per_token = paddle::empty(
+      {token_num_data}, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_q =
+      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_k =
+      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
 #ifdef PADDLE_WITH_COREX
-    int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
+  int blockSize =
+      std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
 #else
-    int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
+  int blockSize =
+      min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
 #endif
-    GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
-        batch_id_per_token.data<int>(),
-        cum_offsets_out.data<int>(),
-        cu_seqlens_q.data<int>(),
-        cu_seqlens_k.data<int>(),
-        cum_offsets.data<int>(),
-        seq_len.data<int>(),
-        seq_length);
-    RemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
-        x_remove_padding.data<int64_t>(),
-        input_ids.data<int64_t>(),
-        seq_len.data<int>(),
-        cum_offsets_out.data<int>(),
-        seq_length);
-    return {x_remove_padding,
-            batch_id_per_token,
-            cu_seqlens_q,
-            cu_seqlens_k};  // , enc_token_num, dec_token_num};
+  PrefixSumKernel<<<bsz, blockSize, 0, cu_stream>>>(
+      x_remove_padding.data<int64_t>(),
+      batch_id_per_token.data<int>(),
+      cu_seqlens_q.data<int>(),
+      cu_seqlens_k.data<int>(),
+      input_ids.data<int64_t>(),
+      seq_len.data<int>(),
+      max_seq_len);
+
+  return {x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k};
 }
 
 std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
     const std::vector<int64_t> &input_ids_shape,
-    const std::vector<int64_t> &cum_offsets_shape,
     const std::vector<int64_t> &token_num_shape,
     const std::vector<int64_t> &seq_len_shape) {
-    int64_t bsz = seq_len_shape[0];
-    int64_t seq_len = input_ids_shape[1];
-    return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
+  int64_t bsz = seq_len_shape[0];
+  int64_t seq_len = input_ids_shape[1];
+  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
 }
 
 std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
     const paddle::DataType &input_ids_dtype,
-    const paddle::DataType &cum_offsets_dtype,
     const paddle::DataType &token_num_dtype,
     const paddle::DataType &seq_len_dtype) {
-    return {input_ids_dtype,
-            seq_len_dtype,
-            seq_len_dtype,
-            seq_len_dtype};
+  return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
 }
 
 PD_BUILD_STATIC_OP(get_padding_offset)
-    .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
+    .Inputs({"input_ids", "token_num", "seq_len"})
     .Outputs({"x_remove_padding",
               "batch_id_per_token",
               "cu_seqlens_q",
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index a5354da8d..521d60455 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -206,10 +206,26 @@ def pre_process(
         cu_seqlens_q:
         cu_seqlens_k:
     """
+    token_num = paddle.sum(seq_lens_this_time)
+
+    if current_platform.is_cuda() and not speculative_decoding:
+        # Note(ZKK): the fused CUDA op handles this case in a single pass.
+        ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
+            input_ids, token_num, seq_lens_this_time
+        )
+
+        return (
+            ids_remove_padding,
+            batch_id_per_token,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            None,
+            None,
+        )
+
     # Remove padding
     max_len = input_ids.shape[1]
     cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
-    token_num = paddle.sum(seq_lens_this_time)
     output_padding_offset = None
     output_cum_offsets = None
     if speculative_decoding:
diff --git a/tests/operators/test_get_padding_offset.py b/tests/operators/test_get_padding_offset.py
index bc29de25c..4bbcf8a15 100644
--- a/tests/operators/test_get_padding_offset.py
+++ b/tests/operators/test_get_padding_offset.py
@@ -22,9 +22,7 @@ from fastdeploy.model_executor.ops.gpu import get_padding_offset
 class TestGetPaddingOffset(unittest.TestCase):
     def test_get_padding_offset(self):
-        max_len = 10
         seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1)
-        cum_offset = np.cumsum((max_len - seq_lens).flatten(), -1, "int32")
         token_num = np.sum(seq_lens)
         input_ids = np.array(
             [[8, 7, 8, 2, 0, 0, 0, 0, 0, 0], [4, 5, 5, 0, 0, 0, 0, 0, 0, 0], [7, 6, 1, 7, 2, 6, 0, 0, 0, 0]], "int64"
         )
@@ -36,7 +34,6 @@ class TestGetPaddingOffset(unittest.TestCase):
             cu_seqlens_k,
         ) = get_padding_offset(
             paddle.to_tensor(input_ids),
-            paddle.to_tensor(cum_offset),
             paddle.to_tensor(token_num),
             paddle.to_tensor(seq_lens),
         )
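Reviewer note: the substance of this change is that `get_padding_offset` no longer needs a host-side `cum_offsets` tensor at all. `PrefixSumKernel` recovers each sequence's start offset in the packed layout by running a warp-level inclusive scan over `seq_lens` inside the kernel. The sketch below replays that scan on the unit test's inputs (`seq_lens = [4, 3, 6]`, `max_seq_len = 10`); the kernel and harness names here (`prefix_sum_demo`, the `d_*` buffers) are illustrative stand-ins, not part of this PR, and error handling is omitted.

```cuda
// Standalone sketch of the warp-scan used by PrefixSumKernel, replayed on the
// unit test's inputs. Names are illustrative; error checks/frees are omitted.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void prefix_sum_demo(int64_t *ids_out, int *batch_id, int *cu_seqlens,
                                const int64_t *ids_in, const int *seq_lens,
                                const int max_seq_len) {
  const int bi = blockIdx.x;  // one block per batch entry
  const int tid = threadIdx.x;
  const int lane_id = tid % 32;

  // Each lane accumulates a strided slice of seq_lens[0..bi]...
  int cum = 0;
  for (int i = lane_id; i <= bi; i += 32) cum += seq_lens[i];

  // ...then a warp-inclusive scan folds the lane partials together, so lane 31
  // ends up holding the full sum, which is broadcast to every lane.
  for (int offset = 1; offset < 32; offset <<= 1) {
    const int tmp = __shfl_up_sync(0xffffffff, cum, offset);
    if (lane_id >= offset) cum += tmp;
  }
  cum = __shfl_sync(0xffffffff, cum, 31);

  if (tid == 0) cu_seqlens[bi + 1] = cum;
  if (bi == 0 && tid == 0) cu_seqlens[0] = 0;

  // Gather this sequence's tokens into the packed (unpadded) layout.
  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
    ids_out[cum - seq_lens[bi] + i] = ids_in[bi * max_seq_len + i];
    batch_id[cum - seq_lens[bi] + i] = bi;
  }
}

int main() {
  const int bsz = 3, max_seq_len = 10, token_num = 13;  // matches the unit test
  const int64_t h_ids[bsz * max_seq_len] = {
      8, 7, 8, 2, 0, 0, 0, 0, 0, 0,
      4, 5, 5, 0, 0, 0, 0, 0, 0, 0,
      7, 6, 1, 7, 2, 6, 0, 0, 0, 0};
  const int h_seq_lens[bsz] = {4, 3, 6};

  int64_t *d_in, *d_out;
  int *d_seq, *d_bid, *d_cu;
  cudaMalloc(&d_in, sizeof(h_ids));
  cudaMalloc(&d_out, token_num * sizeof(int64_t));
  cudaMalloc(&d_seq, sizeof(h_seq_lens));
  cudaMalloc(&d_bid, token_num * sizeof(int));
  cudaMalloc(&d_cu, (bsz + 1) * sizeof(int));
  cudaMemcpy(d_in, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice);
  cudaMemcpy(d_seq, h_seq_lens, sizeof(h_seq_lens), cudaMemcpyHostToDevice);

  prefix_sum_demo<<<bsz, 128>>>(d_out, d_bid, d_cu, d_in, d_seq, max_seq_len);

  int h_cu[bsz + 1];
  int64_t h_out[token_num];
  cudaMemcpy(h_cu, d_cu, sizeof(h_cu), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  for (int i = 0; i <= bsz; ++i) printf("%d ", h_cu[i]);  // expect: 0 4 7 13
  printf("\n");
  for (int i = 0; i < token_num; ++i) printf("%lld ", (long long)h_out[i]);
  printf("\n");
  return 0;
}
```

Compiled with `nvcc`, this prints `0 4 7 13` (the cumulative lengths written to `cu_seqlens`) followed by `8 7 8 2 4 5 5 7 6 1 7 2 6` (the packed ids), matching what the reworked op returns for the test inputs. Deriving the offsets in-kernel is what lets the CUDA fast path in `pre_process` return before the `paddle.cumsum` fallback is ever reached.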