Remove CUDA ERROR 9 of inputs of get_padding_offset kernel (#5440)

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
K11OntheBoat
2025-12-09 14:17:30 +08:00
committed by GitHub
parent 76649b45c1
commit 8d99bac532
8 changed files with 97 additions and 167 deletions

View File

@@ -388,8 +388,8 @@ void GetBlockShapeAndSplitKVBlock(
const int block_size);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len);
const paddle::Tensor& seq_len,
const int64_t token_num_cpu);
void SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
const paddle::Tensor& input_ids,
@@ -725,9 +725,9 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder);
const paddle::Tensor& seq_lens_encoder,
const int64_t token_num_cpu);
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_this_time,

View File

@@ -64,8 +64,8 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
}
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &token_num,
const paddle::Tensor &seq_len) {
const paddle::Tensor &seq_len,
const int64_t cpu_token_num) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(
@@ -77,9 +77,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int max_seq_len = input_ids_shape[1];
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = cpu_token_num.data<int64_t>()[0];
const int token_num_data = cpu_token_num;
auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::empty(
@@ -124,11 +122,12 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
}
PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "token_num", "seq_len"})
.Inputs({"input_ids", "seq_len"})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",
"cu_seqlens_k"})
.Attrs({"cpu_token_num: int64_t"})
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));

View File

@@ -26,19 +26,19 @@ __global__ void SpeculateRemovePadding(int64_t* output_data,
const int* cum_offsets,
const int sequence_length,
const int max_draft_tokens) {
const int bi = blockIdx.x;
const int tid = threadIdx.x;
const int bi = blockIdx.x;
const int tid = threadIdx.x;
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
if (seq_lens_encoder[bi] > 0) {
const int src_seq_id = bi * sequence_length + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
} else {
const int src_seq_id = bi * max_draft_tokens + i;
output_data[tgt_seq_id] = draft_tokens[src_seq_id];
}
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
if (seq_lens_encoder[bi] > 0) {
const int src_seq_id = bi * sequence_length + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
} else {
const int src_seq_id = bi * max_draft_tokens + i;
output_data[tgt_seq_id] = draft_tokens[src_seq_id];
}
}
}
__global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
@@ -48,67 +48,65 @@ __global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len) {
// get padding offset of each batch
const int bi = blockIdx.x;
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
}
if (ti == 0) {
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
}
// get padding offset of each batch
const int bi = blockIdx.x;
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
}
if (ti == 0) {
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
}
}
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder) {
auto cu_stream = input_ids.stream();
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
const int max_draft_tokens = draft_tokens.shape()[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::full(
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::full(
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length);
SpeculateRemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
draft_tokens.data<int64_t>(),
seq_len.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
max_draft_tokens);
return {x_remove_padding,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num};
const paddle::Tensor& seq_lens_encoder,
const int64_t cpu_token_num) {
auto cu_stream = input_ids.stream();
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
const int max_draft_tokens = draft_tokens.shape()[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
const int token_num_data = cpu_token_num;
auto x_remove_padding = paddle::full(
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::full(
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length);
SpeculateRemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
draft_tokens.data<int64_t>(),
seq_len.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
max_draft_tokens);
return {x_remove_padding,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num};
}
std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
@@ -118,9 +116,9 @@ std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
const std::vector<int64_t>& token_num_shape,
const std::vector<int64_t>& seq_len_shape,
const std::vector<int64_t>& seq_lens_encoder_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
@@ -130,23 +128,22 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
const paddle::DataType& token_num_dtype,
const paddle::DataType& seq_len_dtype,
const paddle::DataType& seq_lens_encoder_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
}
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
.Inputs({"input_ids",
"draft_tokens",
"cum_offsets",
"token_num",
"seq_len",
"seq_lens_encoder"})
.Inputs({
"input_ids",
"draft_tokens",
"cum_offsets",
"seq_len",
"seq_lens_encoder",
})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",
"cu_seqlens_k"})
.Attrs({"cpu_token_num: int64_t"})
.SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype));

View File

@@ -27,10 +27,7 @@ from fastdeploy.platforms import current_platform
if current_platform.is_cuda() and current_platform.available():
try:
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,
speculate_get_padding_offset,
)
from fastdeploy.model_executor.ops.gpu import get_padding_offset
except Exception:
raise ImportError(
"Verify environment consistency between compilation and FastDeploy installation. "
@@ -458,57 +455,6 @@ def remove_padding(
)
def speculate_remove_padding(
max_len: paddle.Tensor,
input_ids: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
draft_tokens: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""
Remove padding from sequences.
Args:
max_len (paddle.Tensor): The maximum length of the sequences.
input_ids (paddle.Tensor): The IDs of the input sequences.
seq_lens_this_time (paddle.Tensor): The lengths of the sequences in the current batch.
draft_tokens (paddle.Tensor): The draft tokens.
seq_lens_encoder (paddle.Tensor): The lengths of the encoder sequences.
Returns:
tuple: A tuple containing:
- The input sequence IDs with padding removed (paddle.Tensor).
- Padding offsets (paddle.Tensor).
- Cumulative offsets (paddle.Tensor).
- Query sequence lengths (paddle.Tensor).
- Key sequence lengths (paddle.Tensor).
"""
if current_platform.is_cuda():
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
token_num = paddle.sum(seq_lens_this_time)
(
ids_remove_padding,
cum_offsets,
padding_offset,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids,
draft_tokens,
cum_offsets_now,
token_num,
seq_lens_this_time,
seq_lens_encoder,
)
return (
ids_remove_padding,
padding_offset,
cum_offsets,
cu_seqlens_q,
cu_seqlens_k,
)
class CpuGuard:
"""CpuGuard"""

View File

@@ -183,7 +183,7 @@ def speculate_limit_thinking_content_length(
def pre_process(
input_ids: paddle.Tensor,
seq_lens_this_time: int,
seq_lens_this_time: paddle.Tensor,
speculative_decoding: bool,
draft_tokens: Optional[paddle.Tensor] = None,
seq_lens_encoder: Optional[paddle.Tensor] = None,
@@ -204,15 +204,13 @@ def pre_process(
cu_seqlens_q:
cu_seqlens_k:
"""
token_num = paddle.sum(seq_lens_this_time)
token_num_cpu = seq_lens_this_time.numpy().sum().item()
specific_platform = current_platform.is_cuda() or current_platform.is_maca() or current_platform.is_iluvatar()
if specific_platform and not speculative_decoding:
# Note(ZKK): This case's code is very simple!
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
input_ids, token_num, seq_lens_this_time
input_ids, seq_lens_this_time, token_num_cpu
)
return (
ids_remove_padding,
batch_id_per_token,
@@ -221,7 +219,6 @@ def pre_process(
None,
None,
)
# Remove padding
max_len = input_ids.shape[1]
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
@@ -234,12 +231,7 @@ def pre_process(
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids,
draft_tokens,
cum_offsets_now,
token_num,
seq_lens_this_time,
seq_lens_encoder,
input_ids, draft_tokens, cum_offsets_now, seq_lens_this_time, seq_lens_encoder, token_num_cpu
)
seq_lens_output = speculate_get_seq_lens_output(
seq_lens_this_time,
@@ -257,6 +249,7 @@ def pre_process(
max_len,
)
else:
token_num = paddle.sum(seq_lens_this_time)
(
ids_remove_padding,
batch_id_per_token,

View File

@@ -270,10 +270,10 @@ class TestAttentionPerformance(unittest.TestCase):
partial_rotary_factor=fd_config.model_config.partial_rotary_factor,
)
input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64")
token_num = paddle.sum(seq_lens_this_time)
input_ids = paddle.zeros([batch_size, seq_len if mode == ForwardMode.EXTEND else 1], dtype="int64")
token_num = np.sum(seq_lens_this_time)
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
input_ids, token_num, seq_lens_this_time
input_ids, seq_lens_this_time, token_num
)
forward_meta = ForwardMeta(

View File

@@ -23,7 +23,7 @@ from fastdeploy.model_executor.ops.gpu import get_padding_offset
class TestGetPaddingOffset(unittest.TestCase):
def test_get_padding_offset(self):
seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1)
token_num = np.sum(seq_lens)
token_num_cpu = np.sum(seq_lens)
input_ids = np.array(
[[8, 7, 8, 2, 0, 0, 0, 0, 0, 0], [4, 5, 5, 0, 0, 0, 0, 0, 0, 0], [7, 6, 1, 7, 2, 6, 0, 0, 0, 0]], "int64"
)
@@ -32,11 +32,7 @@ class TestGetPaddingOffset(unittest.TestCase):
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(
paddle.to_tensor(input_ids),
paddle.to_tensor(token_num),
paddle.to_tensor(seq_lens),
)
) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), token_num_cpu)
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32")

View File

@@ -86,14 +86,13 @@ class TestSpeculateGetPaddingOffset(unittest.TestCase):
input_ids = np.random.randint(0, 1000, (test_case["bsz"], test_case["max_seq_len"]), dtype=np.int64)
draft_tokens = np.random.randint(0, 1000, (test_case["bsz"], max_draft_tokens), dtype=np.int64)
token_num = np.array([test_case["token_num_data"]], dtype=np.int64)
token_num_cpu = np.array([test_case["token_num_data"]], dtype=np.int64).item()
input_ids_tensor = paddle.to_tensor(input_ids)
draft_tokens_tensor = paddle.to_tensor(draft_tokens)
cum_offsets_tensor = paddle.to_tensor(test_case["cum_offsets"])
seq_lens_tensor = paddle.to_tensor(test_case["seq_lens"])
seq_lens_encoder_tensor = paddle.to_tensor(test_case["seq_lens_encoder"])
token_num_tensor = paddle.to_tensor(token_num)
(
x_remove_padding,
@@ -104,9 +103,9 @@ class TestSpeculateGetPaddingOffset(unittest.TestCase):
input_ids_tensor,
draft_tokens_tensor,
cum_offsets_tensor,
token_num_tensor,
seq_lens_tensor,
seq_lens_encoder_tensor,
token_num_cpu,
)
(