remove cum_offsets from get_block_shape_and_split_kv_block (#2913)

* remove padding_offsets from get_padding_offset.cu

* remove padding_offsets from get_padding_offset.cu

* remove padding_offsets from get_padding_offset.cu

* remove cum_offsets from get_block_shape_and_split_kv_block

* remove cum_offsets from get_block_shape_and_split_kv_block
This commit is contained in:
周周周
2025-07-18 16:13:32 +08:00
committed by GitHub
parent e81137e581
commit d306944f4f
6 changed files with 6 additions and 14 deletions

View File

@@ -194,12 +194,12 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock( std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets, const paddle::Tensor &seq_lens_this_time,
const int encoder_block_shape_q, const int decoder_block_shape_q, const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size, const int group_size, const int block_size,
const int decoder_step_token_num) { const int decoder_step_token_num) {
auto stream = seq_lens_encoder.stream(); auto stream = seq_lens_encoder.stream();
int bsz = cum_offsets.shape()[0]; int bsz = seq_lens_encoder.shape()[0];
auto max_len_tensor = auto max_len_tensor =
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place()); GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
@@ -335,8 +335,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype( std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const paddle::DataType &seq_lens_encoder_dtype, const paddle::DataType &seq_lens_encoder_dtype,
const paddle::DataType &seq_lens_decoder_dtype, const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_this_time_dtype, const paddle::DataType &seq_lens_this_time_dtype) {
const paddle::DataType &cum_offsets_dtype) {
return { return {
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
@@ -347,8 +346,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape( std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const std::vector<int64_t> &seq_lens_encoder_shape, const std::vector<int64_t> &seq_lens_encoder_shape,
const std::vector<int64_t> &seq_lens_decoder_shape, const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_this_time_shape, const std::vector<int64_t> &seq_lens_this_time_shape) {
const std::vector<int64_t> &cum_offsets_shape) {
std::vector<int64_t> dynamic_shape = {-1}; std::vector<int64_t> dynamic_shape = {-1};
return {dynamic_shape, return {dynamic_shape,
@@ -365,8 +363,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
} }
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
"cum_offsets"})
.Outputs({paddle::Optional("encoder_batch_ids"), .Outputs({paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"), paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks"), paddle::Optional("encoder_num_blocks"),

View File

@@ -234,7 +234,7 @@ paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock( std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets, const paddle::Tensor &seq_lens_this_time,
const int encoder_block_shape_q, const int decoder_block_shape_q, const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size, const int group_size, const int block_size,
const int decoder_step_token_num); const int decoder_step_token_num);

View File

@@ -145,7 +145,6 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.seq_lens_encoder, forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder, forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time, forward_meta.seq_lens_this_time,
forward_meta.cum_offsets,
metadata.encoder_block_shape_q, metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q, metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads, self.num_heads // self.kv_num_heads,

View File

@@ -151,7 +151,6 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.seq_lens_encoder, forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder, forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time, forward_meta.seq_lens_this_time,
forward_meta.cum_offsets,
metadata.encoder_block_shape_q, metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q, metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads, self.num_heads // self.kv_num_heads,

View File

@@ -173,7 +173,6 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.seq_lens_encoder, forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder, forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time, forward_meta.seq_lens_this_time,
forward_meta.cum_offsets,
metadata.encoder_block_shape_q, metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q, metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads, self.num_heads // self.kv_num_heads,

View File

@@ -28,7 +28,6 @@ def get_block_shape_and_split_kv_block(
seq_lens_encoder: paddle.Tensor, seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor, seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor, seq_lens_this_time: paddle.Tensor,
cum_offsets: paddle.Tensor,
encoder_block_shape_q: int, encoder_block_shape_q: int,
decoder_block_shape_q: int, decoder_block_shape_q: int,
group_size: int, group_size: int,
@@ -55,7 +54,6 @@ def get_block_shape_and_split_kv_block(
seq_lens_encoder, seq_lens_encoder,
seq_lens_decoder, seq_lens_decoder,
seq_lens_this_time, seq_lens_this_time,
cum_offsets,
encoder_block_shape_q, encoder_block_shape_q,
decoder_block_shape_q, decoder_block_shape_q,
group_size, group_size,