From 8d99bac532d29ed409ab36c19e61b898fa3d3d7c Mon Sep 17 00:00:00 2001
From: K11OntheBoat
Date: Tue, 9 Dec 2025 14:17:30 +0800
Subject: [PATCH] Remove CUDA ERROR 9 of inputs of get_padding_offset kernel (#5440)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
---
 custom_ops/gpu_ops/cpp_extensions.cc               |   8 +-
 custom_ops/gpu_ops/get_padding_offset.cu           |  11 +-
 .../speculate_get_padding_offset.cu                | 153 +++++++++---------
 fastdeploy/model_executor/layers/utils.py          |  56 +------
 .../model_executor/pre_and_post_process.py         |  17 +-
 tests/layers/test_attention_layer.py               |   6 +-
 tests/operators/test_get_padding_offset.py         |   8 +-
 .../test_speculate_get_padding_offset.py           |   5 +-
 8 files changed, 97 insertions(+), 167 deletions(-)

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index c52971472..3383269c6 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -388,8 +388,8 @@ void GetBlockShapeAndSplitKVBlock(
     const int block_size);
 
 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
-                                             const paddle::Tensor& token_num,
-                                             const paddle::Tensor& seq_len);
+                                             const paddle::Tensor& seq_len,
+                                             const int64_t token_num_cpu);
 
 void SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
                            const paddle::Tensor& input_ids,
@@ -725,9 +725,9 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,
     const paddle::Tensor& draft_tokens,
     const paddle::Tensor& cum_offsets,
-    const paddle::Tensor& token_num,
     const paddle::Tensor& seq_len,
-    const paddle::Tensor& seq_lens_encoder);
+    const paddle::Tensor& seq_lens_encoder,
+    const int64_t token_num_cpu);
 
 std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
     const paddle::Tensor& seq_lens_this_time,
diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu
index 646e0a159..6493941b7 100644
--- a/custom_ops/gpu_ops/get_padding_offset.cu
+++ b/custom_ops/gpu_ops/get_padding_offset.cu
@@ -64,8 +64,8 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
 }
 
 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
-                                             const paddle::Tensor &token_num,
-                                             const paddle::Tensor &seq_len) {
+                                             const paddle::Tensor &seq_len,
+                                             const int64_t cpu_token_num) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   auto dev_ctx = static_cast(
       paddle::experimental::DeviceContextPool::Instance().Get(
@@ -77,9 +77,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   std::vector<int64_t> input_ids_shape = input_ids.shape();
   const int bsz = seq_len.shape()[0];
   const int max_seq_len = input_ids_shape[1];
-  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
-
-  const int token_num_data = cpu_token_num.data()[0];
+  const int token_num_data = cpu_token_num;
   auto x_remove_padding = paddle::empty(
       {token_num_data}, paddle::DataType::INT64, input_ids.place());
   auto batch_id_per_token = paddle::empty(
@@ -124,11 +122,12 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
 }
 
 PD_BUILD_STATIC_OP(get_padding_offset)
-    .Inputs({"input_ids", "token_num", "seq_len"})
+    .Inputs({"input_ids", "seq_len"})
     .Outputs({"x_remove_padding",
               "batch_id_per_token",
              "cu_seqlens_q",
              "cu_seqlens_k"})
+    .Attrs({"cpu_token_num: int64_t"})
     .SetKernelFn(PD_KERNEL(GetPaddingOffset))
     .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
index de9b8333d..d644a4fa3 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
@@ -26,19 +26,19 @@ __global__ void SpeculateRemovePadding(int64_t* output_data,
                                        const int* cum_offsets,
                                        const int sequence_length,
                                        const int max_draft_tokens) {
-    const int bi = blockIdx.x;
-    const int tid = threadIdx.x;
+  const int bi = blockIdx.x;
+  const int tid = threadIdx.x;
 
-    for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
-        const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
-        if (seq_lens_encoder[bi] > 0) {
-            const int src_seq_id = bi * sequence_length + i;
-            output_data[tgt_seq_id] = input_data[src_seq_id];
-        } else {
-            const int src_seq_id = bi * max_draft_tokens + i;
-            output_data[tgt_seq_id] = draft_tokens[src_seq_id];
-        }
+  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
+    const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
+    if (seq_lens_encoder[bi] > 0) {
+      const int src_seq_id = bi * sequence_length + i;
+      output_data[tgt_seq_id] = input_data[src_seq_id];
+    } else {
+      const int src_seq_id = bi * max_draft_tokens + i;
+      output_data[tgt_seq_id] = draft_tokens[src_seq_id];
     }
+  }
 }
 
 __global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
@@ -48,67 +48,65 @@ __global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
                                                 const int* cum_offsets,
                                                 const int* seq_lens,
                                                 const int max_seq_len) {
-    // get padding offset of each batch
-    const int bi = blockIdx.x;
-    const int ti = threadIdx.x;
-    int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
-    for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
-        batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
-    }
-    if (ti == 0) {
-        cum_offsets_out[bi] = cum_offset;
-        int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
-        cu_seqlens_q[bi + 1] = cum_seq_len;
-        cu_seqlens_k[bi + 1] = cum_seq_len;
-    }
+  // get padding offset of each batch
+  const int bi = blockIdx.x;
+  const int ti = threadIdx.x;
+  int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
+  for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
+    batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
+  }
+  if (ti == 0) {
+    cum_offsets_out[bi] = cum_offset;
+    int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
+    cu_seqlens_q[bi + 1] = cum_seq_len;
+    cu_seqlens_k[bi + 1] = cum_seq_len;
+  }
 }
 
 std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,
     const paddle::Tensor& draft_tokens,
     const paddle::Tensor& cum_offsets,
-    const paddle::Tensor& token_num,
     const paddle::Tensor& seq_len,
-    const paddle::Tensor& seq_lens_encoder) {
-    auto cu_stream = input_ids.stream();
-    std::vector<int64_t> input_ids_shape = input_ids.shape();
-    const int bsz = seq_len.shape()[0];
-    const int seq_length = input_ids_shape[1];
-    const int max_draft_tokens = draft_tokens.shape()[1];
-    auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
-    auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
-
-    const int token_num_data = cpu_token_num.data()[0];
-    auto x_remove_padding = paddle::full(
-        {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
-    auto batch_id_per_token = paddle::full(
-        {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
-    auto cu_seqlens_q =
-        paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-    auto cu_seqlens_k =
-        paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-    int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
-    SpeculateGetPaddingOffsetKernel<<>>(
-        batch_id_per_token.data(),
-        cum_offsets_out.data(),
-        cu_seqlens_q.data(),
-        cu_seqlens_k.data(),
-        cum_offsets.data(),
-        seq_len.data(),
-        seq_length);
-    SpeculateRemovePadding<<>>(
-        x_remove_padding.data(),
-        input_ids.data(),
-        draft_tokens.data(),
-        seq_len.data(),
-        seq_lens_encoder.data(),
-        cum_offsets_out.data(),
-        seq_length,
-        max_draft_tokens);
-    return {x_remove_padding,
-            batch_id_per_token,
-            cu_seqlens_q,
-            cu_seqlens_k};  // , enc_token_num, dec_token_num};
+    const paddle::Tensor& seq_lens_encoder,
+    const int64_t cpu_token_num) {
+  auto cu_stream = input_ids.stream();
+  std::vector<int64_t> input_ids_shape = input_ids.shape();
+  const int bsz = seq_len.shape()[0];
+  const int seq_length = input_ids_shape[1];
+  const int max_draft_tokens = draft_tokens.shape()[1];
+  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
+  const int token_num_data = cpu_token_num;
+  auto x_remove_padding = paddle::full(
+      {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
+  auto batch_id_per_token = paddle::full(
+      {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_q =
+      paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_k =
+      paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
+  int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
+  SpeculateGetPaddingOffsetKernel<<>>(
+      batch_id_per_token.data(),
+      cum_offsets_out.data(),
+      cu_seqlens_q.data(),
+      cu_seqlens_k.data(),
+      cum_offsets.data(),
+      seq_len.data(),
+      seq_length);
+  SpeculateRemovePadding<<>>(
+      x_remove_padding.data(),
+      input_ids.data(),
+      draft_tokens.data(),
+      seq_len.data(),
+      seq_lens_encoder.data(),
+      cum_offsets_out.data(),
+      seq_length,
+      max_draft_tokens);
+  return {x_remove_padding,
+          batch_id_per_token,
+          cu_seqlens_q,
+          cu_seqlens_k};  // , enc_token_num, dec_token_num};
 }
 
 std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
     const std::vector<int64_t>& input_ids_shape,
     const std::vector<int64_t>& draft_tokens_shape,
     const std::vector<int64_t>& cum_offsets_shape,
@@ -118,9 +116,9 @@ std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
     const std::vector<int64_t>& token_num_shape,
     const std::vector<int64_t>& seq_len_shape,
     const std::vector<int64_t>& seq_lens_encoder_shape) {
-    int64_t bsz = seq_len_shape[0];
-    int64_t seq_len = input_ids_shape[1];
-    return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
+  int64_t bsz = seq_len_shape[0];
+  int64_t seq_len = input_ids_shape[1];
+  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
 }
 
 std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
@@ -130,23 +128,22 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
     const paddle::DataType& input_ids_dtype,
     const paddle::DataType& draft_tokens_dtype,
     const paddle::DataType& cum_offsets_dtype,
     const paddle::DataType& token_num_dtype,
     const paddle::DataType& seq_len_dtype,
     const paddle::DataType& seq_lens_encoder_dtype) {
-    return {input_ids_dtype,
-            seq_len_dtype,
-            seq_len_dtype,
-            seq_len_dtype};
+  return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
 }
 
 PD_BUILD_STATIC_OP(speculate_get_padding_offset)
-    .Inputs({"input_ids",
-             "draft_tokens",
-             "cum_offsets",
-             "token_num",
-             "seq_len",
-             "seq_lens_encoder"})
+    .Inputs({
+        "input_ids",
+        "draft_tokens",
+        "cum_offsets",
+        "seq_len",
+        "seq_lens_encoder",
+    })
     .Outputs({"x_remove_padding", "batch_id_per_token", "cu_seqlens_q", "cu_seqlens_k"})
+    .Attrs({"cpu_token_num: int64_t"})
     .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
     .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype));
diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py
index 914c0260d..66c347f68 100644
--- a/fastdeploy/model_executor/layers/utils.py
+++ b/fastdeploy/model_executor/layers/utils.py
@@ -27,10 +27,7 @@ from fastdeploy.platforms import current_platform
 
 if current_platform.is_cuda() and current_platform.available():
     try:
-        from fastdeploy.model_executor.ops.gpu import (
-            get_padding_offset,
-            speculate_get_padding_offset,
-        )
+        from fastdeploy.model_executor.ops.gpu import get_padding_offset
     except Exception:
         raise ImportError(
             "Verify environment consistency between compilation and FastDeploy installation. "
@@ -458,57 +455,6 @@ def remove_padding(
     )
 
 
-def speculate_remove_padding(
-    max_len: paddle.Tensor,
-    input_ids: paddle.Tensor,
-    seq_lens_this_time: paddle.Tensor,
-    draft_tokens: paddle.Tensor,
-    seq_lens_encoder: paddle.Tensor,
-) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
-    """
-    Remove padding from sequences.
-
-    Args:
-        max_len (paddle.Tensor): The maximum length of the sequences.
-        input_ids (paddle.Tensor): The IDs of the input sequences.
-        seq_lens_this_time (paddle.Tensor): The lengths of the sequences in the current batch.
-        draft_tokens (paddle.Tensor): The draft tokens.
-        seq_lens_encoder (paddle.Tensor): The lengths of the encoder sequences.
-
-    Returns:
-        tuple: A tuple containing:
-            - The input sequence IDs with padding removed (paddle.Tensor).
-            - Padding offsets (paddle.Tensor).
-            - Cumulative offsets (paddle.Tensor).
-            - Query sequence lengths (paddle.Tensor).
-            - Key sequence lengths (paddle.Tensor).
- """ - if current_platform.is_cuda(): - cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") - token_num = paddle.sum(seq_lens_this_time) - ( - ids_remove_padding, - cum_offsets, - padding_offset, - cu_seqlens_q, - cu_seqlens_k, - ) = speculate_get_padding_offset( - input_ids, - draft_tokens, - cum_offsets_now, - token_num, - seq_lens_this_time, - seq_lens_encoder, - ) - return ( - ids_remove_padding, - padding_offset, - cum_offsets, - cu_seqlens_q, - cu_seqlens_k, - ) - - class CpuGuard: """CpuGuard""" diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 4a4132597..5b771b6d7 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -183,7 +183,7 @@ def speculate_limit_thinking_content_length( def pre_process( input_ids: paddle.Tensor, - seq_lens_this_time: int, + seq_lens_this_time: paddle.Tensor, speculative_decoding: bool, draft_tokens: Optional[paddle.Tensor] = None, seq_lens_encoder: Optional[paddle.Tensor] = None, @@ -204,15 +204,13 @@ def pre_process( cu_seqlens_q: cu_seqlens_k: """ - token_num = paddle.sum(seq_lens_this_time) - + token_num_cpu = seq_lens_this_time.numpy().sum().item() specific_platform = current_platform.is_cuda() or current_platform.is_maca() or current_platform.is_iluvatar() if specific_platform and not speculative_decoding: # Note(ZKK): This case's code is very simple! ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset( - input_ids, token_num, seq_lens_this_time + input_ids, seq_lens_this_time, token_num_cpu ) - return ( ids_remove_padding, batch_id_per_token, @@ -221,7 +219,6 @@ def pre_process( None, None, ) - # Remove padding max_len = input_ids.shape[1] cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") @@ -234,12 +231,7 @@ def pre_process( cu_seqlens_q, cu_seqlens_k, ) = speculate_get_padding_offset( - input_ids, - draft_tokens, - cum_offsets_now, - token_num, - seq_lens_this_time, - seq_lens_encoder, + input_ids, draft_tokens, cum_offsets_now, seq_lens_this_time, seq_lens_encoder, token_num_cpu ) seq_lens_output = speculate_get_seq_lens_output( seq_lens_this_time, @@ -257,6 +249,7 @@ def pre_process( max_len, ) else: + token_num = paddle.sum(seq_lens_this_time) ( ids_remove_padding, batch_id_per_token, diff --git a/tests/layers/test_attention_layer.py b/tests/layers/test_attention_layer.py index 0acbada35..deffb5a73 100644 --- a/tests/layers/test_attention_layer.py +++ b/tests/layers/test_attention_layer.py @@ -270,10 +270,10 @@ class TestAttentionPerformance(unittest.TestCase): partial_rotary_factor=fd_config.model_config.partial_rotary_factor, ) - input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64") - token_num = paddle.sum(seq_lens_this_time) + input_ids = paddle.zeros([batch_size, seq_len if mode == ForwardMode.EXTEND else 1], dtype="int64") + token_num = np.sum(seq_lens_this_time) ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset( - input_ids, token_num, seq_lens_this_time + input_ids, seq_lens_this_time, token_num ) forward_meta = ForwardMeta( diff --git a/tests/operators/test_get_padding_offset.py b/tests/operators/test_get_padding_offset.py index 4bbcf8a15..cfa7760d8 100644 --- a/tests/operators/test_get_padding_offset.py +++ b/tests/operators/test_get_padding_offset.py @@ -23,7 +23,7 @@ from fastdeploy.model_executor.ops.gpu import get_padding_offset class 
TestGetPaddingOffset(unittest.TestCase): def test_get_padding_offset(self): seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1) - token_num = np.sum(seq_lens) + token_num_cpu = np.sum(seq_lens) input_ids = np.array( [[8, 7, 8, 2, 0, 0, 0, 0, 0, 0], [4, 5, 5, 0, 0, 0, 0, 0, 0, 0], [7, 6, 1, 7, 2, 6, 0, 0, 0, 0]], "int64" ) @@ -32,11 +32,7 @@ class TestGetPaddingOffset(unittest.TestCase): batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - ) = get_padding_offset( - paddle.to_tensor(input_ids), - paddle.to_tensor(token_num), - paddle.to_tensor(seq_lens), - ) + ) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), token_num_cpu) ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64") ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32") diff --git a/tests/operators/test_speculate_get_padding_offset.py b/tests/operators/test_speculate_get_padding_offset.py index a9e0b3031..a8aac690b 100644 --- a/tests/operators/test_speculate_get_padding_offset.py +++ b/tests/operators/test_speculate_get_padding_offset.py @@ -86,14 +86,13 @@ class TestSpeculateGetPaddingOffset(unittest.TestCase): input_ids = np.random.randint(0, 1000, (test_case["bsz"], test_case["max_seq_len"]), dtype=np.int64) draft_tokens = np.random.randint(0, 1000, (test_case["bsz"], max_draft_tokens), dtype=np.int64) - token_num = np.array([test_case["token_num_data"]], dtype=np.int64) + token_num_cpu = np.array([test_case["token_num_data"]], dtype=np.int64).item() input_ids_tensor = paddle.to_tensor(input_ids) draft_tokens_tensor = paddle.to_tensor(draft_tokens) cum_offsets_tensor = paddle.to_tensor(test_case["cum_offsets"]) seq_lens_tensor = paddle.to_tensor(test_case["seq_lens"]) seq_lens_encoder_tensor = paddle.to_tensor(test_case["seq_lens_encoder"]) - token_num_tensor = paddle.to_tensor(token_num) ( x_remove_padding, @@ -104,9 +103,9 @@ class TestSpeculateGetPaddingOffset(unittest.TestCase): input_ids_tensor, draft_tokens_tensor, cum_offsets_tensor, - token_num_tensor, seq_lens_tensor, seq_lens_encoder_tensor, + token_num_cpu, ) (
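
For reference, a minimal usage sketch of the updated get_padding_offset signature, in the style of tests/operators/test_get_padding_offset.py above (the tensor values are illustrative only and this snippet is not part of the diff): the total token count is now reduced on the host and passed as a plain Python int (the new "cpu_token_num" int64_t attribute) instead of a device-side token_num tensor, which is what removes the invalid kernel-launch configuration (CUDA error 9) path.

    import numpy as np
    import paddle
    from fastdeploy.model_executor.ops.gpu import get_padding_offset

    # Padded [bsz, max_seq_len] prompt ids and per-request lengths (illustrative values).
    seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1)
    input_ids = np.array(
        [[8, 7, 8, 2, 0, 0, 0, 0, 0, 0],
         [4, 5, 5, 0, 0, 0, 0, 0, 0, 0],
         [7, 6, 1, 7, 2, 6, 0, 0, 0, 0]], "int64"
    )

    # Host-side reduction replaces the old GPU token_num tensor input.
    token_num_cpu = int(np.sum(seq_lens))

    x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
        paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), token_num_cpu
    )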