diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index a7584162b..63aec90af 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -59,7 +59,6 @@ void AppendAttentionKernel( const paddle::Tensor& decoder_tile_ids_per_batch, const paddle::Tensor& decoder_num_blocks, const paddle::Tensor& set_max_lengths, - const paddle::Tensor& max_len_kv, paddle::Tensor& fmha_out, const paddle::optional& rotary_embs, const paddle::optional& attn_mask, @@ -103,6 +102,7 @@ void AppendAttentionKernel( int max_dec_len_this_time = set_max_lengths.data()[2]; int max_enc_dec_len_this_time = set_max_lengths.data()[3]; int max_just_dec_len_this_time = set_max_lengths.data()[4]; + int max_kv_len_this_time = set_max_lengths.data()[8]; auto main_stream = qkv.stream(); static cudaEvent_t main_event; @@ -245,7 +245,6 @@ void AppendAttentionKernel( if (max_just_dec_len_this_time > 0) { int decoder_num_blocks_data = decoder_num_blocks.data()[0]; - int max_len_kv_data = max_len_kv.data()[0]; cudaStream_t exec_stream; if (max_enc_len_this_time > 0) { @@ -371,20 +370,20 @@ void AppendAttentionKernel( case paddle::DataType::INT8:{ int8_t tmp; dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data, - decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream); + decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream); break; } case paddle::DataType::FLOAT8_E4M3FN:{ phi::dtype::float8_e4m3fn tmp; dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data, - decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream); + decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream); break; } } } else { data_t tmp; dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data, - decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream); + decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream); } if (max_enc_len_this_time > 0) { cudaEventRecord(decoder_event, exec_stream); @@ -413,7 +412,6 @@ std::vector AppendAttention( const paddle::Tensor& decoder_tile_ids_per_batch, const paddle::Tensor& decoder_num_blocks, const paddle::Tensor& set_max_lengths, - const paddle::Tensor& max_len_kv, const paddle::optional& rotary_embs, const paddle::optional& attn_mask, const paddle::optional& qkv_bias, @@ -539,7 +537,6 @@ std::vector AppendAttention( decoder_tile_ids_per_batch, decoder_num_blocks, set_max_lengths, - max_len_kv, fmha_out, rotary_embs, attn_mask, @@ -616,7 +613,6 @@ void AppendAttentionWithOutput( const paddle::Tensor& decoder_tile_ids_per_batch, const paddle::Tensor& decoder_num_blocks, const paddle::Tensor& set_max_lengths, - const paddle::Tensor& max_len_kv, paddle::Tensor& fmha_out, const paddle::optional& rotary_embs, const paddle::optional& attn_mask, @@ -695,7 +691,6 @@ void AppendAttentionWithOutput( decoder_tile_ids_per_batch, decoder_num_blocks, set_max_lengths, - max_len_kv, fmha_out, rotary_embs, attn_mask, @@ -784,7 +779,6 @@ std::vector> AppendAttentionInferShape( const std::vector& decoder_tile_ids_per_batch_shape, const std::vector& decoder_num_blocks_shape, const std::vector& set_max_lengths_shape, - const std::vector& max_len_kv_shape, const paddle::optional>& rotary_embs_shape, const paddle::optional>& attn_mask_shape, const paddle::optional>& qkv_bias_shape, @@ -848,7 +842,6 @@ std::vector AppendAttentionInferDtype( const paddle::DataType& decoder_tile_ids_per_batch_dtype, const paddle::DataType& decoder_num_blocks_dtype, const paddle::DataType& set_max_lengths_dtype, - const paddle::DataType& max_len_kv_dtype, const paddle::optional& rotary_embs_dtype, const paddle::optional& attn_mask_dtype, const paddle::optional& qkv_bias_dtype, @@ -930,7 +923,6 @@ std::vector> AppendAttentionWithOutputInferShape( const std::vector& decoder_tile_ids_per_batch_shape, const std::vector& decoder_num_blocks_shape, const std::vector& set_max_lengths_shape, - const std::vector& max_len_kv_shape, const std::vector& fmha_out_shape, const paddle::optional>& rotary_embs_shape, const paddle::optional>& attn_mask_shape, @@ -987,7 +979,6 @@ std::vector AppendAttentionWithOutputInferDtype( const paddle::DataType& decoder_tile_ids_per_batch_dtype, const paddle::DataType& decoder_num_blocks_dtype, const paddle::DataType& set_max_lengths_dtype, - const paddle::DataType& max_len_kv_dtype, const paddle::DataType& fmha_out_dtype, const paddle::optional& rotary_embs_dtype, const paddle::optional& attn_mask_dtype, @@ -1046,7 +1037,6 @@ PD_BUILD_STATIC_OP(append_attention) "decoder_tile_ids_per_batch", "decoder_num_blocks", "set_max_lengths", - "max_len_kv", paddle::Optional("rotary_embs"), paddle::Optional("attn_mask"), paddle::Optional("qkv_bias"), @@ -1105,7 +1095,6 @@ PD_BUILD_STATIC_OP(append_attention_with_output) "decoder_tile_ids_per_batch", "decoder_num_blocks", "set_max_lengths", - "max_len_kv", "fmha_out", paddle::Optional("rotary_embs"), paddle::Optional("attn_mask"), diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index b5e2baf87..d17316fec 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -19,7 +19,7 @@ template __global__ void -GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time, +GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, const int *seq_lens_encoder, const int *seq_lens_this_time_merged, const int *seq_lens_encoder_merged, const int *seq_mapping, @@ -37,41 +37,27 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time, int max_just_dec_merged_len_this_time_this_thread = 0; int max_system_len_this_thread = 0; int max_dec_len_without_system_this_thread = 0; + int max_len_kv_this_thread = 0; for (int i = tid; i < batch_size; i += blockDim.x) { const int seq_len_this_time = seq_lens_this_time[i]; + const int seq_len_decoder = seq_lens_decoder[i]; max_len_this_time_this_thread = max(seq_len_this_time, max_len_this_time_this_thread); max_len_encoder_this_thread = max(seq_lens_encoder[i], max_len_encoder_this_thread); - max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread); + max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread); if (seq_len_this_time <= 0) continue; - const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i]; + const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder; max_len_this_thread = - max(seq_lens[i] + seq_len_this_time, max_len_this_thread); + max(seq_len_decoder + seq_len_this_time, max_len_this_thread); max_just_dec_len_this_thread = max(max_just_dec_len_this_thread, max_just_dec_len_now); - if (system_lens) { - const int real_bid = seq_mapping[i]; - const int system_len_now = system_lens[real_bid]; - max_system_len_this_thread = - max(max_system_len_this_thread, system_len_now); - max_dec_len_without_system_this_thread = - max(max_dec_len_without_system_this_thread, - max_just_dec_len_now - system_len_now); - } - } - if (system_lens) { - for (int i = tid; i < batch_size; i += blockDim.x) { - const int ori_seq_len_this_time = seq_lens_this_time_merged[i]; - if (ori_seq_len_this_time <= 0) - continue; - const int max_just_dec_merged_len_this_time_now = - seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time; - max_just_dec_merged_len_this_time_this_thread = - max(max_just_dec_merged_len_this_time_this_thread, - max_just_dec_merged_len_this_time_now); - } + + if (seq_len_decoder == 0) + continue; + max_len_kv_this_thread = + max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread); } int total_max_len_this_time = BlockReduce(temp_storage) @@ -94,6 +80,8 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time, int total_dec_len_without_system = BlockReduce(temp_storage) .Reduce(max_dec_len_without_system_this_thread, MaxOp()); + int total_max_len_kv = + BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp()); if (tid == 0) { max_lens[0] = total_max_len_this_time; max_lens[1] = total_max_len_encoder; @@ -103,6 +91,7 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time, max_lens[5] = total_just_dec_merged; max_lens[6] = total_system_len; max_lens[7] = total_dec_len_without_system; + max_lens[8] = total_max_len_kv; } } @@ -256,29 +245,6 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder, } } -template -__global__ void -get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time, - const int *seq_lens_decoder, const int batch_size) { - const int tid = threadIdx.x; - - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int max_len_this_thread = 0; - for (int i = tid; i < batch_size; i += blockDim.x) { - if (seq_lens_decoder[i] == 0) - continue; - max_len_this_thread = - max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread); - } - int total = - BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); - if (tid == 0) { - *max_seq_lens_out = total; - } -} - void GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, @@ -295,7 +261,6 @@ void GetBlockShapeAndSplitKVBlock( paddle::Tensor &kv_batch_ids, // Inplace paddle::Tensor &kv_tile_ids_per_batch, // Inplace paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU - paddle::Tensor &max_len_kv_cpu, // Inplace, CPU const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, @@ -319,15 +284,7 @@ void GetBlockShapeAndSplitKVBlock( int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5]; int max_system_len = max_len_cpu_ptr[6]; int max_just_dec_len_without_system = max_len_cpu_ptr[7]; - - auto max_len_kv = - GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place()); - get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>( - max_len_kv.data(), seq_lens_this_time.data(), - seq_lens_decoder.data(), bsz); - - - max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false); + int max_kv_len_this_time = max_len_cpu_ptr[8]; // decoder if (max_dec_len_this_time > 0) { @@ -430,7 +387,7 @@ void GetBlockShapeAndSplitKVBlock( decoder_num_blocks_cpu.copy_( decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); } } else { PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( @@ -492,7 +449,6 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) "kv_batch_ids", "kv_tile_ids_per_batch", "kv_num_blocks_x_cpu", - "max_len_kv_cpu" }) .Outputs({ diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index ce4fa1420..57d6201ef 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -64,7 +64,7 @@ std::vector AppendAttention( const paddle::Tensor &decoder_batch_ids, const paddle::Tensor &decoder_tile_ids_per_batch, const paddle::Tensor &decoder_num_blocks_cpu, - const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv, + const paddle::Tensor &set_max_lengths, const paddle::optional &rotary_embs, const paddle::optional &attn_mask, const paddle::optional &qkv_bias, @@ -106,7 +106,7 @@ void AppendAttentionWithOutput( const paddle::Tensor &decoder_batch_ids, const paddle::Tensor &decoder_tile_ids_per_batch, const paddle::Tensor &decoder_num_blocks_cpu, - const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv, + const paddle::Tensor &set_max_lengths, paddle::Tensor &fmha_out, const paddle::optional &rotary_embs, const paddle::optional &attn_mask, @@ -315,7 +315,6 @@ void GetBlockShapeAndSplitKVBlock( paddle::Tensor &kv_batch_ids, // Inplace paddle::Tensor &kv_tile_ids_per_batch, // Inplace paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory - paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 9d780dc19..0114bb53f 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -119,8 +119,6 @@ class ForwardMeta: kv_tile_ids_per_batch: Optional[paddle.Tensor] = None # The number of CUDA blocks to launch in the x-dimension for the append_write_cache_kv kernel, defining its grids.x. kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None - # The maximum sequence length of the KV cache, which may represent the current maximum decoder length. - max_len_kv_cpu: Optional[paddle.Tensor] = None decoder_chunk_size_device: Optional[paddle.Tensor] = None diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 4e015e003..f97bf2015 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -150,7 +150,6 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.kv_batch_ids, forward_meta.kv_tile_ids_per_batch, forward_meta.kv_num_blocks_x_cpu, - forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, @@ -291,7 +290,6 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, - forward_meta.max_len_kv_cpu, res, metadata.rotary_embs, metadata.attn_mask, @@ -347,7 +345,6 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, - forward_meta.max_len_kv_cpu, metadata.rotary_embs, metadata.attn_mask, layer.qkv_bias, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index ee57c7754..54d8595d6 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -207,7 +207,6 @@ class FlashAttentionBackend(AttentionBackend): forward_meta.kv_batch_ids, forward_meta.kv_tile_ids_per_batch, forward_meta.kv_num_blocks_x_cpu, - forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, @@ -340,7 +339,6 @@ class FlashAttentionBackend(AttentionBackend): forward_meta.decoder_tile_ids_per_batch, # from buffer forward_meta.decoder_num_blocks_cpu, metadata.max_len_tensor_cpu_decoder, - forward_meta.max_len_kv_cpu, metadata.rotary_embs, forward_meta.attn_mask, layer.qkv_bias, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 5c283c84d..d7d18526f 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -83,6 +83,7 @@ class MLAAttentionMetadata(AttentionMetadata): max_enc_len_this_time: Optional[paddle.Tensor] = None max_dec_len_this_time: Optional[paddle.Tensor] = None + max_kv_len_this_time: Optional[paddle.Tensor] = None class MLAAttentionBackend(AttentionBackend): @@ -199,7 +200,6 @@ class MLAAttentionBackend(AttentionBackend): forward_meta.kv_batch_ids, forward_meta.kv_tile_ids_per_batch, forward_meta.kv_num_blocks_x_cpu, - forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, @@ -210,6 +210,7 @@ class MLAAttentionBackend(AttentionBackend): # MLA metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1] metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2] + metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8] # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers @@ -362,7 +363,7 @@ class MLAAttentionBackend(AttentionBackend): forward_meta.decoder_num_blocks_device, forward_meta.decoder_chunk_size_device, metadata.max_dec_len_this_time, - forward_meta.max_len_kv_cpu, + metadata.max_kv_len_this_time, None, # attn_mask None, # qkv_bias None, # qkv_out_scales @@ -478,7 +479,7 @@ class MLAAttentionBackend(AttentionBackend): forward_meta.decoder_num_blocks_device, forward_meta.decoder_chunk_size_device, metadata.max_dec_len_this_time, - forward_meta.max_len_kv_cpu, + metadata.max_kv_len_this_time, None, # attn_mask None, # qkv_bias None, # qkv_out_scales diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py index 7cf963687..6216d0cd1 100644 --- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py +++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py @@ -49,7 +49,6 @@ def append_attention( decoder_tile_ids_per_batch: paddle.Tensor, decoder_num_blocks: paddle.Tensor, set_max_lengths: paddle.Tensor, - max_len_kv: paddle.Tensor, rotary_embs: Optional[paddle.Tensor] = None, attn_mask: Optional[paddle.Tensor] = None, qkv_bias: Optional[paddle.Tensor] = None, @@ -107,7 +106,6 @@ def append_attention( decoder_tile_ids_per_batch, decoder_num_blocks, set_max_lengths, - max_len_kv, rotary_embs, attn_mask, qkv_bias, @@ -169,7 +167,6 @@ def append_attention_with_output( decoder_tile_ids_per_batch: paddle.Tensor, decoder_num_blocks: paddle.Tensor, set_max_lengths: paddle.Tensor, - max_len_kv: paddle.Tensor, out: paddle.tensor, # attention output rotary_embs: Optional[paddle.Tensor] = None, attn_mask: Optional[paddle.Tensor] = None, @@ -228,7 +225,6 @@ def append_attention_with_output( decoder_tile_ids_per_batch, decoder_num_blocks, set_max_lengths, - max_len_kv, out, rotary_embs, attn_mask, diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index edcf8a692..1cd5f4f14 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -40,7 +40,6 @@ def get_block_shape_and_split_kv_block( kv_batch_ids: paddle.Tensor, kv_tile_ids_per_batch: paddle.Tensor, kv_num_blocks_x_cpu: paddle.Tensor, - max_len_kv_cpu: paddle.Tensor, encoder_block_shape_q: int, decoder_block_shape_q: int, group_size: int, @@ -67,7 +66,6 @@ def get_block_shape_and_split_kv_block( kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks_x_cpu, - max_len_kv_cpu, encoder_block_shape_q, decoder_block_shape_q, group_size, diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 14a3e4fec..e41563cc1 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -247,7 +247,6 @@ class MTPProposer(Proposer): self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like( self.target_model_inputs["kv_num_blocks_x_cpu"] ).cpu() - self.model_inputs["max_len_kv_cpu"] = paddle.zeros_like(self.target_model_inputs["max_len_kv_cpu"]).cpu() # Get the attention backend attn_cls = get_attention_backend() @@ -374,7 +373,6 @@ class MTPProposer(Proposer): self.model_inputs["kv_batch_ids"] = None self.model_inputs["kv_tile_ids_per_batch"] = None self.model_inputs["kv_num_blocks_x_cpu"] = None # CPU - self.model_inputs["max_len_kv_cpu"] = None # CPU # Input tokens self.model_inputs["draft_tokens"] = paddle.full( @@ -583,7 +581,6 @@ class MTPProposer(Proposer): kv_batch_ids=self.model_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"], - max_len_kv_cpu=self.model_inputs["max_len_kv_cpu"], ) # Initialzie attention meta data diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 5c0580ea8..b8351625a 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -436,7 +436,6 @@ class GCUModelRunner(ModelRunnerBase): self.share_inputs["kv_batch_ids"] = None self.share_inputs["kv_tile_ids_per_batch"] = None self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU - self.share_inputs["max_len_kv_cpu"] = None # CPU # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) @@ -614,7 +613,6 @@ class GCUModelRunner(ModelRunnerBase): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"], ) # Update Batch type for cuda graph @@ -703,7 +701,7 @@ class GCUModelRunner(ModelRunnerBase): self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + self.share_inputs["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu() self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") @@ -712,7 +710,6 @@ class GCUModelRunner(ModelRunnerBase): self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu() # Get the attention backend attn_cls = get_attention_backend() diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1c55b4679..8d1be796f 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -873,7 +873,6 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["kv_batch_ids"] = None self.share_inputs["kv_tile_ids_per_batch"] = None self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU - self.share_inputs["max_len_kv_cpu"] = None # CPU # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) @@ -1119,7 +1118,6 @@ class GPUModelRunner(ModelRunnerBase): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"], ) # Update Batch type for cuda graph for only_decode_batch @@ -1280,7 +1278,7 @@ class GPUModelRunner(ModelRunnerBase): # adapted to cudagraph. self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32") self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32") - self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + self.share_inputs["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu() self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") @@ -1289,7 +1287,6 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu() # Get the attention backend attn_cls = get_attention_backend() diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index 6da6681e7..35a5331cb 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -394,7 +394,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() - self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu() self.cache_shape = ( self.max_block_num, @@ -495,7 +494,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.kv_batch_ids, self.kv_tile_ids_per_batch, self.kv_num_blocks_x_cpu, - self.max_len_kv_cpu, 64, 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, @@ -529,7 +527,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, self.max_len_tensor_cpu, - self.max_len_kv_cpu, self.rope_emb, # rope_emb None, # attn_mask None, # qkv_bias @@ -591,7 +588,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, self.max_len_tensor_cpu, - self.max_len_kv_cpu, self.rope_emb, # rope_emb None, # attn_mask None, # qkv_bias diff --git a/tests/layers/test_append_attention_with_output.py b/tests/layers/test_append_attention_with_output.py index c198d1291..ea2c1802b 100644 --- a/tests/layers/test_append_attention_with_output.py +++ b/tests/layers/test_append_attention_with_output.py @@ -391,7 +391,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() - self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu() self.cache_shape = ( self.max_block_num, @@ -476,7 +475,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.kv_batch_ids, self.kv_tile_ids_per_batch, self.kv_num_blocks_x_cpu, - self.max_len_kv_cpu, 64, 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, @@ -512,7 +510,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, self.max_len_tensor_cpu, - self.max_len_kv_cpu, out, self.rope_emb, # rope_emb None, # attn_mask diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index a6bb8bd46..795c2354e 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -204,14 +204,13 @@ class TestTreeMask(unittest.TestCase): decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory() decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") - max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([9], 0, dtype="int32").cpu() encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() kv_batch_ids = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") kv_tile_ids_per_batch = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() - max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu() q_norm_weight = np.ones([self.head_dim]) k_norm_weight = np.ones([self.head_dim]) self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32") @@ -233,7 +232,6 @@ class TestTreeMask(unittest.TestCase): kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks_x_cpu, - max_len_kv_cpu, encoder_block_shape_q, decoder_block_shape_q, self.num_q_head // self.num_kv_head, @@ -264,7 +262,6 @@ class TestTreeMask(unittest.TestCase): decoder_tile_ids_per_batch, decoder_num_blocks, max_len_tensor_cpu, - max_len_kv_cpu, rotary_embs, attn_mask, None, # qkv_bias