diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 42bae453e..e438380e2 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -194,12 +194,12 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time, std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets, + const paddle::Tensor &seq_lens_this_time, const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, const int block_size, const int decoder_step_token_num) { auto stream = seq_lens_encoder.stream(); - int bsz = cum_offsets.shape()[0]; + int bsz = seq_lens_encoder.shape()[0]; auto max_len_tensor = GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place()); GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, @@ -335,8 +335,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock( std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype( const paddle::DataType &seq_lens_encoder_dtype, const paddle::DataType &seq_lens_decoder_dtype, - const paddle::DataType &seq_lens_this_time_dtype, - const paddle::DataType &cum_offsets_dtype) { + const paddle::DataType &seq_lens_this_time_dtype) { return { paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, @@ -347,8 +346,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype( std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape( const std::vector<int64_t> &seq_lens_encoder_shape, const std::vector<int64_t> &seq_lens_decoder_shape, - const std::vector<int64_t> &seq_lens_this_time_shape, - const std::vector<int64_t> &cum_offsets_shape) { + const std::vector<int64_t> &seq_lens_this_time_shape) { std::vector<int64_t> dynamic_shape
= {-1}; return {dynamic_shape, @@ -365,8 +363,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape( } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) - .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", - "cum_offsets"}) + .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"}) .Outputs({paddle::Optional("encoder_batch_ids"), paddle::Optional("encoder_tile_ids_per_batch"), paddle::Optional("encoder_num_blocks"), diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 38bd4b67f..e1d48f41c 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -234,7 +234,7 @@ paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata, std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets, + const paddle::Tensor &seq_lens_this_time, const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, const int block_size, const int decoder_step_token_num); diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 7da552d70..311fb6bce 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -145,7 +145,6 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, - forward_meta.cum_offsets, metadata.encoder_block_shape_q, metadata.decoder_block_shape_q, self.num_heads // self.kv_num_heads, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 4c1cde80b..97b0b1bb7 100644 ---
a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -151,7 +151,6 @@ class FlashAttentionBackend(AttentionBackend): forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, - forward_meta.cum_offsets, metadata.encoder_block_shape_q, metadata.decoder_block_shape_q, self.num_heads // self.kv_num_heads, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 3940eb780..e11469e96 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -173,7 +173,6 @@ class MLAAttentionBackend(AttentionBackend): forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, - forward_meta.cum_offsets, metadata.encoder_block_shape_q, metadata.decoder_block_shape_q, self.num_heads // self.kv_num_heads, diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index 097c228bd..97c3e6f9b 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -28,7 +28,6 @@ def get_block_shape_and_split_kv_block( seq_lens_encoder: paddle.Tensor, seq_lens_decoder: paddle.Tensor, seq_lens_this_time: paddle.Tensor, - cum_offsets: paddle.Tensor, encoder_block_shape_q: int, decoder_block_shape_q: int, group_size: int, @@ -55,7 +54,6 @@ def get_block_shape_and_split_kv_block( seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, - cum_offsets, encoder_block_shape_q, decoder_block_shape_q, group_size,