[get_padding_offset.] clean get_padding_offset.cu (#4777)

[get_padding_offset.] clean get_padding_offset.cu (#4777)
This commit is contained in:
周周周
2025-11-05 10:47:40 +08:00
committed by GitHub
parent 1c3ca48128
commit 937eb3c6ed
4 changed files with 95 additions and 91 deletions

View File

@@ -384,7 +384,6 @@ void GetBlockShapeAndSplitKVBlock(
const int decoder_step_token_num);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len);

View File

@@ -12,127 +12,119 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void RemovePadding(int64_t *output_data,
const int64_t *input_data,
const int *seq_lens,
const int *cum_offsets,
const int sequence_length) {
const int bi = blockIdx.x;
const int tid = threadIdx.x;
__global__ void PrefixSumKernel(int64_t *ids_remove_padding,
int *batch_id_per_token,
int *cu_seqlens_q,
int *cu_seqlens_k,
const int64_t *input_data,
const int *seq_lens,
const int max_seq_len) {
const int bi = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = threadIdx.x / 32;
const int lane_id = threadIdx.x % 32;
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
const int src_seq_id = bi * sequence_length + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
}
}
int cum_seq_len = 0;
__global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
int *cum_offsets_out,
int *cu_seqlens_q,
int *cu_seqlens_k,
const int *cum_offsets,
const int *seq_lens,
const int max_seq_len) {
// get padding offset of each batch
const int bi = blockIdx.x;
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
#ifdef PADDLE_WITH_HIP
batch_id_per_token[bi * max_seq_len - cum_offset + i] = cum_offset;
#else
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
#endif
}
if (ti == 0) {
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
}
// compute sum of seq_lens[0,1,2,...,bi]
for (int i = lane_id; i < bi + 1; i += warpSize) {
cum_seq_len += seq_lens[i];
}
for (int offset = 1; offset < warpSize; offset <<= 1) {
const int tmp = __shfl_up_sync(0xffffffff, cum_seq_len, offset);
if (lane_id >= offset) cum_seq_len += tmp;
}
cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, warpSize - 1);
if (tid == 0) {
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
}
if (bi == 0 && tid == 0) {
cu_seqlens_q[0] = 0;
cu_seqlens_k[0] = 0;
}
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i;
const int src_seq_id = bi * max_seq_len + i;
ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
batch_id_per_token[tgt_seq_id] = bi;
}
}
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &token_num,
const paddle::Tensor &seq_len) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
auto cu_stream = dev_ctx->stream();
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(
input_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = input_ids.stream();
auto cu_stream = input_ids.stream();
#endif
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int max_seq_len = input_ids_shape[1];
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
#ifdef PADDLE_WITH_COREX
int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
int blockSize =
std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#else
int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
int blockSize =
min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#endif
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length);
RemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
seq_len.data<int>(),
cum_offsets_out.data<int>(),
seq_length);
return {x_remove_padding,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num};
PrefixSumKernel<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
input_ids.data<int64_t>(),
seq_len.data<int>(),
max_seq_len);
return {x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k};
}
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &input_ids_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &token_num_shape,
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &input_ids_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
}
PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Inputs({"input_ids", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",

View File

@@ -206,10 +206,26 @@ def pre_process(
cu_seqlens_q:
cu_seqlens_k:
"""
token_num = paddle.sum(seq_lens_this_time)
if current_platform.is_cuda() and not speculative_decoding:
# Note(ZKK): This case's code is very simple!
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
input_ids, token_num, seq_lens_this_time
)
return (
ids_remove_padding,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
None,
None,
)
# Remove padding
max_len = input_ids.shape[1]
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
token_num = paddle.sum(seq_lens_this_time)
output_padding_offset = None
output_cum_offsets = None
if speculative_decoding:

View File

@@ -22,9 +22,7 @@ from fastdeploy.model_executor.ops.gpu import get_padding_offset
class TestGetPaddingOffset(unittest.TestCase):
def test_get_padding_offset(self):
max_len = 10
seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1)
cum_offset = np.cumsum((max_len - seq_lens).flatten(), -1, "int32")
token_num = np.sum(seq_lens)
input_ids = np.array(
[[8, 7, 8, 2, 0, 0, 0, 0, 0, 0], [4, 5, 5, 0, 0, 0, 0, 0, 0, 0], [7, 6, 1, 7, 2, 6, 0, 0, 0, 0]], "int64"
@@ -36,7 +34,6 @@ class TestGetPaddingOffset(unittest.TestCase):
cu_seqlens_k,
) = get_padding_offset(
paddle.to_tensor(input_ids),
paddle.to_tensor(cum_offset),
paddle.to_tensor(token_num),
paddle.to_tensor(seq_lens),
)