diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 0cf20ebbd..a46f427b9 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -195,22 +195,25 @@ std::vector GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, - const int encoder_block_shape_q, const int decoder_block_shape_q, - const int group_size, const int block_size, - const int decoder_step_token_num) { + paddle::Tensor &decoder_batch_ids, // Inplace + paddle::Tensor &decoder_tile_ids_per_batch, // Inplace + paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int group_size, + const int block_size, + const int decoder_step_token_num) +{ auto stream = seq_lens_encoder.stream(); int bsz = seq_lens_this_time.shape()[0]; - auto max_len_tensor = - GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place()); - GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, - max_len_tensor, bsz); - // max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, - // max_enc_dec_len_this_time, max_just_dec_len_this_time, - // max_just_dec_merged_len_this_time, max_system_len, - // max_just_dec_len_without_system - auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false); - auto max_len_cpu_ptr = max_len_cpu.data(); + paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place()); + GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, + max_len_tensor_gpu, bsz); + max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + + auto max_len_cpu_ptr = max_len_tensor_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; int max_enc_len_this_time = max_len_cpu_ptr[1]; int max_dec_len_this_time = max_len_cpu_ptr[2]; @@ -222,14 +225,11 @@ std::vector GetBlockShapeAndSplitKVBlock( paddle::Tensor encoder_batch_ids; paddle::Tensor encoder_tile_ids_per_batch; - paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/ + paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/ paddle::Tensor kv_batch_ids; paddle::Tensor kv_tile_ids_per_batch; - paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor decoder_batch_ids; - paddle::Tensor decoder_tile_ids_per_batch; - paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor max_len_kv_cpu; /*cpu*/ + paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/ + paddle::Tensor max_len_kv_cpu; /*cpu*/ auto max_len_kv = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place()); @@ -291,92 +291,64 @@ std::vector GetBlockShapeAndSplitKVBlock( kv_num_blocks_x_cpu = GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); } - if (max_just_dec_len_this_time > 0) { - const uint32_t decoder_max_tile_size_per_bs_q = - div_up((decoder_step_token_num * group_size), decoder_block_shape_q); - decoder_batch_ids = - GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); - decoder_tile_ids_per_batch = - GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); + if (max_just_dec_len_this_time > 0) { + // Clear 
buffer + const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q); + const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q; + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data(), 0, decoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data(), 0, decoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data(), 0, sizeof(int32_t), stream)); + auto decoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_q_block<<<1, 32, 0, stream>>>( - seq_lens_this_time.data(), seq_lens_encoder.data(), - decoder_batch_ids.data(), decoder_tile_ids_per_batch.data(), - decoder_num_blocks_x.data(), bsz, decoder_block_shape_q, + seq_lens_this_time.data(), + seq_lens_encoder.data(), + decoder_batch_ids.data(), + decoder_tile_ids_per_batch.data(), + decoder_num_blocks_x.data(), + bsz, + decoder_block_shape_q, group_size); - decoder_num_blocks_x_cpu = - decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); - } else { - decoder_batch_ids = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - decoder_tile_ids_per_batch = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - decoder_num_blocks_x_cpu = - GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); + decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false); } - return {encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks_x_cpu, /*cpu*/ - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks_x_cpu, /*cpu*/ - decoder_batch_ids, - decoder_tile_ids_per_batch, - decoder_num_blocks_x_cpu, /*cpu*/ - max_len_kv_cpu /*cpu*/, - max_len_cpu}; -} - -std::vector GetBlockShapeAndSplitKVBlockInferDtype( - const paddle::DataType &seq_lens_encoder_dtype, - const paddle::DataType &seq_lens_decoder_dtype, - const paddle::DataType &seq_lens_this_time_dtype) { return { - paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, - paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, - paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, - paddle::DataType::INT32, paddle::DataType::INT32}; -} - -std::vector> GetBlockShapeAndSplitKVBlockInferShape( - const std::vector &seq_lens_encoder_shape, - const std::vector &seq_lens_decoder_shape, - const std::vector &seq_lens_this_time_shape) { - std::vector dynamic_shape = {-1}; - - return {dynamic_shape, - dynamic_shape, - {1}, - dynamic_shape, - dynamic_shape, - {1}, - dynamic_shape, - dynamic_shape, - {1}, - {1}, - {8}}; + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_x_cpu, /*cpu*/ + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, /*cpu*/ + max_len_kv_cpu, /*cpu*/ + }; } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) - .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"}) - .Outputs({paddle::Optional("encoder_batch_ids"), - paddle::Optional("encoder_tile_ids_per_batch"), - paddle::Optional("encoder_num_blocks"), - paddle::Optional("kv_batch_ids"), - paddle::Optional("kv_tile_ids_per_batch"), - paddle::Optional("kv_num_blocks"), - paddle::Optional("decoder_batch_ids"), - paddle::Optional("decoder_tile_ids_per_batch"), - paddle::Optional("decoder_num_blocks"), - paddle::Optional("max_len_kv"), "set_max_lengths"}) - 
.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int", - "group_size: int", "block_size: int", - "decoder_step_token_num: int"}) - .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)) - .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype)); + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "decoder_batch_ids", + "decoder_tile_ids_per_batch", + "decoder_num_blocks_x_cpu", + "max_len_tensor_cpu" + }) + .Outputs({ + paddle::Optional("encoder_batch_ids"), + paddle::Optional("encoder_tile_ids_per_batch"), + paddle::Optional("encoder_num_blocks_x_cpu"), + paddle::Optional("kv_batch_ids"), + paddle::Optional("kv_tile_ids_per_batch"), + paddle::Optional("kv_num_blocks_x_cpu"), + "max_len_kv_cpu" + }) + .Attrs({ + "encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "group_size: int", + "block_size: int", + "decoder_step_token_num: int" + }) + .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)); diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 266d50599..b4d7b952d 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -235,8 +235,14 @@ std::vector GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, - const int encoder_block_shape_q, const int decoder_block_shape_q, - const int group_size, const int block_size, + paddle::Tensor &decoder_batch_ids, // Inplace + paddle::Tensor &decoder_tile_ids_per_batch, // Inplace + paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int group_size, + const int block_size, const int decoder_step_token_num); std::vector GetPaddingOffset(const paddle::Tensor &input_ids, diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 8ee6396fc..be5d7f702 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -77,6 +77,10 @@ class ForwardMeta: decoder_batch_ids: Optional[paddle.Tensor] = None # Tile ID for each batch of the decoder. Used by attention backend. 
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None + # The number of blocks that attention backend can use in decode stage + decoder_num_blocks_cpu: Optional[paddle.Tensor] = None + # Recorded multiple lengths related to prefill or decode + max_len_tensor_cpu: Optional[paddle.Tensor] = None # Sequence length of encoder for ever batch seq_lens_encoder: Optional[paddle.Tensor] = None diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index a148d3250..cffc4adf7 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -48,17 +48,13 @@ class AppendAttentionMetadata(AttentionMetadata): AppendAttentionMetadata """ - max_len_kv: paddle.Tensor = None - set_max_lengths: int = -1 encoder_batch_ids: paddle.Tensor = None encoder_tile_ids_per_batch: paddle.Tensor = None encoder_num_blocks: paddle.Tensor = None kv_batch_ids: paddle.Tensor = None kv_tile_ids_per_batch: paddle.Tensor = None kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + max_len_kv: paddle.Tensor = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -66,8 +62,6 @@ class AppendAttentionMetadata(AttentionMetadata): block_tables: Optional[paddle.Tensor] = None rotary_embs: Optional[paddle.Tensor] = None attn_mask: Optional[paddle.Tensor] = None - encoder_block_shape_q: int = -1 - decoder_block_shape_q: int = -1 _fuse_kernel_compute_dtype: str = "bf16" # pd_disaggregation @@ -89,6 +83,8 @@ class AppendAttentionBackend(AttentionBackend): kv_num_heads: int, num_heads: int, head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, ) -> None: """ AppendAttentionBackend __init__ @@ -110,9 +106,12 @@ class AppendAttentionBackend(AttentionBackend): self.kv_num_heads: int = kv_num_heads self.num_heads: int = num_heads + self.group_size: int = self.num_heads // self.kv_num_heads self.head_dim: int = fd_config.model_config.head_dim self.num_layers: int = fd_config.model_config.num_hidden_layers self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768)) + self.encoder_block_shape_q: int = encoder_block_shape_q + self.decoder_block_shape_q: int = decoder_block_shape_q self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode @@ -126,8 +125,6 @@ class AppendAttentionBackend(AttentionBackend): def init_attention_metadata(self, forward_meta: ForwardMeta): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" metadata = AppendAttentionMetadata() - metadata.encoder_block_shape_q = 64 - metadata.decoder_block_shape_q = 16 metadata.max_partition_size = self.max_partition_size metadata.encoder_max_partition_size = self.max_seq_len metadata._dtype = paddle.get_default_dtype() @@ -148,18 +145,18 @@ class AppendAttentionBackend(AttentionBackend): metadata.kv_batch_ids, metadata.kv_tile_ids_per_batch, metadata.kv_num_blocks, - metadata.decoder_batch_ids, # will copy to buffer - metadata.decoder_tile_ids_per_batch, # will copy to buffer - metadata.decoder_num_blocks, metadata.max_len_kv, - metadata.set_max_lengths, ) = get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, - metadata.encoder_block_shape_q, - 
metadata.decoder_block_shape_q, - self.num_heads // self.kv_num_heads, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu, + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.group_size, self.block_size, self.speculate_max_draft_token_num + 1, ) @@ -181,8 +178,6 @@ class AppendAttentionBackend(AttentionBackend): ) self.attention_metadata: AttentionMetadata = metadata - forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False) - forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False) def get_attntion_meta(self) -> AttentionMetadata: """get_attntion_meta""" @@ -249,10 +244,10 @@ class AppendAttentionBackend(AttentionBackend): metadata.kv_batch_ids, metadata.kv_tile_ids_per_batch, metadata.kv_num_blocks, - forward_meta.decoder_batch_ids, # from buffer - forward_meta.decoder_tile_ids_per_batch, # from buffer - metadata.decoder_num_blocks, - metadata.set_max_lengths, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu, metadata.max_len_kv, metadata.rotary_embs, metadata.attn_mask, @@ -275,8 +270,8 @@ class AppendAttentionBackend(AttentionBackend): getattr(layer, "quant_max_bound", 0.0), getattr(layer, "quant_min_bound", 0.0), getattr(layer, "out_scale", -1.0), - metadata.encoder_block_shape_q, - metadata.decoder_block_shape_q, + self.encoder_block_shape_q, + self.decoder_block_shape_q, metadata.max_partition_size, metadata.encoder_max_partition_size, self.speculate_max_draft_token_num + 1, diff --git a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py index c9ca9cdec..2802e97ba 100644 --- a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py @@ -38,17 +38,12 @@ class BlockAttentionMetadata(AttentionMetadata): BlockAttentionMetadata """ - max_len_kv: paddle.Tensor = None - set_max_lengths: int = -1 encoder_batch_ids: paddle.Tensor = None encoder_tile_ids_per_batch: paddle.Tensor = None encoder_num_blocks: paddle.Tensor = None kv_batch_ids: paddle.Tensor = None kv_tile_ids_per_batch: paddle.Tensor = None kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -56,8 +51,6 @@ class BlockAttentionMetadata(AttentionMetadata): block_tables: Optional[paddle.Tensor] = None rotary_embs: Optional[paddle.Tensor] = None attn_mask: Optional[paddle.Tensor] = None - encoder_block_shape_q: int = -1 - decoder_block_shape_q: int = -1 _fuse_kernel_compute_dtype: str = "bf16" # pd_disaggregation @@ -79,6 +72,8 @@ class BlockAttentionBackend(AttentionBackend): kv_num_heads: int, num_heads: int, head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, ): """ BlockAttentionBackend __init__ diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index fcbf6fa64..306164635 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -53,8 +53,6 @@ class 
FlashAttentionMetadata(AttentionMetadata): FlashAttentionMetadata """ - max_len_kv: paddle.Tensor = None - set_max_lengths: int = -1 rotary_embs: Optional[paddle.Tensor] = None block_tables: Optional[paddle.Tensor] = None encoder_batch_ids: paddle.Tensor = None @@ -63,12 +61,6 @@ class FlashAttentionMetadata(AttentionMetadata): kv_batch_ids: paddle.Tensor = None kv_tile_ids_per_batch: paddle.Tensor = None kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None - - encoder_block_shape_q: int = -1 - decoder_block_shape_q: int = -1 cu_seqlens_q: paddle.Tensor = None cu_seqlens_k: paddle.Tensor = None @@ -100,6 +92,8 @@ class FlashAttentionBackend(AttentionBackend): kv_num_heads: int, num_heads: int, head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, ): """ FlashAttentionBackend __init__ @@ -111,10 +105,13 @@ class FlashAttentionBackend(AttentionBackend): self.kv_num_heads = kv_num_heads self.num_heads = num_heads + self.group_size: int = self.num_heads // self.kv_num_heads self.head_dim = fd_config.model_config.head_dim self.attn_outputsize_tp = self.num_heads * self.head_dim self.block_size = fd_config.cache_config.block_size self.num_layers: int = fd_config.model_config.num_hidden_layers + self.encoder_block_shape_q: int = encoder_block_shape_q + self.decoder_block_shape_q: int = decoder_block_shape_q self.speculative_method = fd_config.speculative_config.method self.use_speculate = self.speculative_method is not None @@ -176,8 +173,6 @@ class FlashAttentionBackend(AttentionBackend): def init_attention_metadata(self, forward_meta: ForwardMeta): metadata = FlashAttentionMetadata() - metadata.encoder_block_shape_q = 64 - metadata.decoder_block_shape_q = 16 metadata.cu_seqlens_q = forward_meta.cu_seqlens_q metadata.rotary_embs = forward_meta.rotary_embs metadata.block_tables = forward_meta.block_tables @@ -188,18 +183,18 @@ class FlashAttentionBackend(AttentionBackend): metadata.kv_batch_ids, metadata.kv_tile_ids_per_batch, metadata.kv_num_blocks, - metadata.decoder_batch_ids, - metadata.decoder_tile_ids_per_batch, - metadata.decoder_num_blocks, metadata.max_len_kv, - metadata.set_max_lengths, ) = get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, - metadata.encoder_block_shape_q, - metadata.decoder_block_shape_q, - self.num_heads // self.kv_num_heads, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu, + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.group_size, self.block_size, self.speculate_max_draft_token_num + 1, ) @@ -233,8 +228,6 @@ class FlashAttentionBackend(AttentionBackend): self.rank, int(self.device_id), self.keep_pd_step_flag ) self.attention_metadata = metadata - forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False) - forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False) def forward_mixed( self, @@ -291,8 +284,8 @@ class FlashAttentionBackend(AttentionBackend): v, metadata.cu_seqlens_q, metadata.cu_seqlens_k, - max_seqlen_q=metadata.set_max_lengths[0], - max_seqlen_k=metadata.set_max_lengths[3], + max_seqlen_q=forward_meta.max_len_tensor_cpu[0], + max_seqlen_k=forward_meta.max_len_tensor_cpu[3], causal=self.causal, **self.flash_attn_kwargs, )[0].reshape([-1, self.attn_outputsize_tp]) 
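Illustrative note (not part of the patch): with this change the op no longer allocates or returns the decoder split results. The caller owns `decoder_batch_ids`, `decoder_tile_ids_per_batch`, a pinned-memory `decoder_num_blocks_x_cpu`, and a CPU `max_len_tensor_cpu`, and the op fills them in place, so the same buffers can be reused across steps (e.g., under CUDA graph replay). Below is a minimal Python sketch of the new calling convention; the sizes, dummy sequence lengths, and the import path (inferred from the file location) are assumptions, not part of this patch.

```python
# Sketch only: assumed sizes; buffer sizing mirrors decode_max_tile_size in the runner changes below.
import numpy as np
import paddle

from fastdeploy.model_executor.layers.attention.ops.get_block_shape_and_split_kv_block import (
    get_block_shape_and_split_kv_block,
)

max_num_seqs = 64                                 # assumed runner capacity
group_size = 8                                    # num_heads // kv_num_heads, assumed
encoder_block_shape_q, decoder_block_shape_q = 64, 16
block_size = 64                                   # assumed KV cache block size
decoder_step_token_num = 1                        # speculate_max_draft_token_num + 1 when speculating

decode_max_tile_size = int(
    max_num_seqs * np.ceil(decoder_step_token_num * group_size / decoder_block_shape_q)
)

# Caller-owned buffers, reused every step; the op clears and refills them in place.
decoder_batch_ids = paddle.zeros([decode_max_tile_size], dtype="int32")
decoder_tile_ids_per_batch = paddle.zeros([decode_max_tile_size], dtype="int32")
decoder_num_blocks_cpu = paddle.zeros([1], dtype="int32").pin_memory()   # pinned host memory
max_len_tensor_cpu = paddle.zeros([8], dtype="int32").cpu()              # 8 max-length statistics

# Dummy inputs for the sketch (shapes assumed; normally prepared by the model runner).
seq_lens_encoder = paddle.zeros([max_num_seqs, 1], dtype="int32")        # no prefill this step
seq_lens_decoder = paddle.full([max_num_seqs, 1], 128, dtype="int32")    # 128 cached tokens each
seq_lens_this_time = paddle.ones([max_num_seqs, 1], dtype="int32")       # one new token each

(
    encoder_batch_ids,
    encoder_tile_ids_per_batch,
    encoder_num_blocks,
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks,
    max_len_kv_cpu,
) = get_block_shape_and_split_kv_block(
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    decoder_batch_ids,           # filled in place
    decoder_tile_ids_per_batch,  # filled in place
    decoder_num_blocks_cpu,      # filled in place (pinned)
    max_len_tensor_cpu,          # filled in place (CPU)
    encoder_block_shape_q,
    decoder_block_shape_q,
    group_size,
    block_size,
    decoder_step_token_num,
)
```
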
diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 20fa775ed..5279b68f6 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -64,17 +64,13 @@ class MLAAttentionMetadata(AttentionMetadata): MLAAttentionMetadata for Multi-Layer Attention """ - max_len_kv: paddle.Tensor = None - set_max_lengths: int = -1 encoder_batch_ids: paddle.Tensor = None encoder_tile_ids_per_batch: paddle.Tensor = None encoder_num_blocks: paddle.Tensor = None kv_batch_ids: paddle.Tensor = None kv_tile_ids_per_batch: paddle.Tensor = None kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + max_len_kv: paddle.Tensor = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -82,8 +78,6 @@ class MLAAttentionMetadata(AttentionMetadata): block_tables: Optional[paddle.Tensor] = None rotary_embs: Optional[paddle.Tensor] = None attn_mask: Optional[paddle.Tensor] = None - encoder_block_shape_q: int = -1 - decoder_block_shape_q: int = -1 _fuse_kernel_compute_dtype: str = "bf16" # pd_disaggregation @@ -105,6 +99,8 @@ class MLAAttentionBackend(AttentionBackend): kv_num_heads: int, num_heads: int, head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, ) -> None: """ MLAAttentionBackend __init__ @@ -128,8 +124,11 @@ class MLAAttentionBackend(AttentionBackend): self.kv_num_heads: int = kv_num_heads self.num_heads: int = num_heads + self.group_size: int = self.num_heads // self.kv_num_heads self.head_dim: int = fd_config.model_config.head_dim self.num_layers: int = fd_config.model_config.num_hidden_layers + self.encoder_block_shape_q: int = encoder_block_shape_q + self.decoder_block_shape_q: int = decoder_block_shape_q # For Multi Head Latent Attention self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank @@ -152,8 +151,6 @@ class MLAAttentionBackend(AttentionBackend): def init_attention_metadata(self, forward_meta: ForwardMeta): """Initialize attention metadata hence all layers in the forward pass can reuse it.""" metadata = MLAAttentionMetadata() - metadata.encoder_block_shape_q = 64 - metadata.decoder_block_shape_q = 16 metadata.max_partition_size = 32768 metadata.encoder_max_partition_size = self.max_seq_len metadata._dtype = paddle.get_default_dtype() @@ -176,27 +173,25 @@ class MLAAttentionBackend(AttentionBackend): metadata.kv_batch_ids, metadata.kv_tile_ids_per_batch, metadata.kv_num_blocks, - metadata.decoder_batch_ids, - metadata.decoder_tile_ids_per_batch, - metadata.decoder_num_blocks, metadata.max_len_kv, - metadata.set_max_lengths, ) = get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, - metadata.encoder_block_shape_q, - metadata.decoder_block_shape_q, - self.num_heads // self.kv_num_heads, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu, + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.group_size, self.block_size, self.speculate_max_draft_token_num + 1, ) # MLA - metadata.max_enc_len_this_time = metadata.set_max_lengths[1] - metadata.max_dec_len_this_time = metadata.set_max_lengths[2] - forward_meta.max_enc_len_this_time = 
metadata.set_max_lengths[1] - forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2] + metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1] + metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2] # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers @@ -216,9 +211,6 @@ class MLAAttentionBackend(AttentionBackend): self.attention_metadata: AttentionMetadata = metadata - forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False) - forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False) - def get_attntion_meta(self) -> AttentionMetadata: """get_attntion_meta""" return self.attention_metadata @@ -354,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend): metadata.kv_num_blocks, forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, - metadata.decoder_num_blocks, - metadata.decoder_num_blocks, # PaddleNLP 传入的是 decoder_num_blocks_cpu + forward_meta.decoder_num_blocks_cpu, + forward_meta.decoder_num_blocks_cpu, metadata.max_enc_len_this_time, metadata.max_dec_len_this_time, metadata.max_len_kv, @@ -476,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend): metadata.kv_num_blocks, forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, - metadata.decoder_num_blocks, - metadata.decoder_num_blocks, # PaddleNLP 传入的是 decoder_num_blocks_cpu + forward_meta.decoder_num_blocks_cpu, + forward_meta.decoder_num_blocks_cpu, metadata.max_enc_len_this_time, metadata.max_dec_len_this_time, metadata.max_len_kv, diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index f2e252a42..dd57b5259 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -28,6 +28,10 @@ def get_block_shape_and_split_kv_block( seq_lens_encoder: paddle.Tensor, seq_lens_decoder: paddle.Tensor, seq_lens_this_time: paddle.Tensor, + decoder_batch_ids: paddle.Tensor, + decoder_tile_ids_per_batch: paddle.Tensor, + decoder_num_blocks_x_cpu: paddle.Tensor, + max_len_tensor_cpu: paddle.Tensor, encoder_block_shape_q: int, decoder_block_shape_q: int, group_size: int, @@ -45,15 +49,15 @@ def get_block_shape_and_split_kv_block( kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks, - decoder_batch_ids, - decoder_tile_ids_per_batch, - decoder_num_blocks, - max_len_kv, - set_max_lengths, + max_len_kv_cpu, ) = get_block_shape_and_split_kv_block_cuda( seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_x_cpu, + max_len_tensor_cpu, encoder_block_shape_q, decoder_block_shape_q, group_size, @@ -67,11 +71,7 @@ def get_block_shape_and_split_kv_block( kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks, - decoder_batch_ids, - decoder_tile_ids_per_batch, - decoder_num_blocks, - max_len_kv, - set_max_lengths, + max_len_kv_cpu, ) else: raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 0cdf605d2..45ae75184 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -44,26 +44,13 @@ class XPUAttentionMetadata(AttentionMetadata): XPUAttentionMetadata """ - max_len_kv: paddle.Tensor = 
None - set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None - _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 block_tables: Optional[paddle.Tensor] = None rotary_embs: Optional[paddle.Tensor] = None attn_mask: Optional[paddle.Tensor] = None - encoder_block_shape_q: int = -1 - decoder_block_shape_q: int = -1 + _fuse_kernel_compute_dtype: str = "bf16" # pd_disaggregation @@ -91,7 +78,6 @@ class XPUAttentionBackend(AttentionBackend): """ super().__init__() self.attention_metadata: XPUAttentionMetadata = None - # TODO(gongshaotian): Use fd_config parameters in the correct location self.block_size: int = fd_config.cache_config.block_size self.max_seq_len: int = fd_config.parallel_config.max_model_len self.rope_theta: float = ( @@ -99,9 +85,6 @@ class XPUAttentionBackend(AttentionBackend): ) self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) self.causal: bool = getattr(fd_config.model_config, "causal", True) - # self.speculate_method = fd_config.parallel_config.speculate_method - # self.use_speculate = self.speculate_method is not None - # self.speculate_max_draft_token_num = fd_config.parallel_config.speculate_max_draft_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.rank: int = fd_config.parallel_config.tensor_parallel_rank @@ -117,8 +100,6 @@ class XPUAttentionBackend(AttentionBackend): def init_attention_metadata(self, forward_meta: ForwardMeta): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" metadata = XPUAttentionMetadata() - metadata.encoder_block_shape_q = 64 - metadata.decoder_block_shape_q = 16 metadata.max_partition_size = 32768 metadata.encoder_max_partition_size = 32768 metadata._dtype = paddle.get_default_dtype() diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 6503a8d38..39f0fce42 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -184,13 +184,26 @@ class MTPProposer(Proposer): """ assert len(self.attn_backends) == 0 - # TODO(gongshaotian): Get rank from config num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size - self.model_config.kv_num_heads = ( - int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size + self.model_config.kv_num_heads = max( + 1, + int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size, ) head_dim = self.model_config.head_dim + # Initialize AttentionBackend buffers + encoder_block_shape_q = 64 + decoder_block_shape_q = 16 + + self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"]) + self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like( + self.main_model_inputs["decoder_tile_ids_per_batch"] + ) + self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( + self.main_model_inputs["decoder_num_blocks_cpu"] + ).pin_memory() + self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu() + # Get the attention backend attn_cls = get_attention_backend() 
attn_backend = attn_cls( @@ -198,6 +211,8 @@ class MTPProposer(Proposer): kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim, + encoder_block_shape_q=encoder_block_shape_q, + decoder_block_shape_q=decoder_block_shape_q, ) if attn_backend is None: raise NotImplementedError( @@ -293,6 +308,12 @@ class MTPProposer(Proposer): self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"] self.model_inputs["substep"] = 0 + # Declare AttentionBackend buffers + self.model_inputs["decoder_batch_ids"] = None + self.model_inputs["decoder_tile_ids_per_batch"] = None + self.model_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory + self.model_inputs["max_len_tensor_cpu"] = None # CPU + # Input tokens self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64") @@ -405,6 +426,8 @@ class MTPProposer(Proposer): attn_backend=self.attn_backends[0], decoder_batch_ids=self.model_inputs["decoder_batch_ids"], decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"], + decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"], + max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"], seq_lens_encoder=self.model_inputs["seq_lens_encoder"], seq_lens_decoder=self.model_inputs["seq_lens_decoder"], seq_lens_this_time=self.model_inputs["seq_lens_this_time"], diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 850c3fc9a..531304017 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -417,9 +417,11 @@ class GCUModelRunner(ModelRunnerBase): self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") - # AttentionBackend buffers - self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") - self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + # Declare AttentionBackend buffers + self.share_inputs["decoder_batch_ids"] = None + self.share_inputs["decoder_tile_ids_per_batch"] = None + self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory + self.share_inputs["max_len_tensor_cpu"] = None # CPU # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) @@ -579,6 +581,8 @@ class GCUModelRunner(ModelRunnerBase): attn_backend=self.attn_backends[0], decoder_batch_ids=self.share_inputs["decoder_batch_ids"], decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], + decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"], + max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"], seq_lens_this_time=self.share_inputs["seq_lens_this_time"], @@ -655,6 +659,18 @@ class GCUModelRunner(ModelRunnerBase): ) head_dim = self.model_config.head_dim + # Initialize AttentionBackend buffers + encoder_block_shape_q = 64 + decoder_block_shape_q = 16 + decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 + decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + ) + 
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() + self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -662,6 +678,8 @@ class GCUModelRunner(ModelRunnerBase): kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim, + encoder_block_shape_q=encoder_block_shape_q, + decoder_block_shape_q=decoder_block_shape_q, ) if attn_backend is None: raise NotImplementedError( @@ -1179,9 +1197,5 @@ class GCUModelRunner(ModelRunnerBase): Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. """ - # TODO(gongshaotian): Use more efficient implementation - if self.forward_meta.step_use_cudagraph: - num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum() - for i in range(1, num_empty_batch + 1): - self.forward_meta.decoder_batch_ids[-i] = 0 - self.forward_meta.decoder_tile_ids_per_batch[-i] = 0 + # In init_attention_metadata, the decode buffer has already been cleared + return diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1fb6235f9..4b67b595e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -610,9 +610,7 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") - self.share_inputs["not_need_stop"] = paddle.full( - [1], False, dtype="bool" - ).cpu() # TODO(gongshaotian): move to pinnd memory + self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu() self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool") self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64") @@ -643,9 +641,12 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") - # AttentionBackend buffers - self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") - self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + + # Declare AttentionBackend buffers + self.share_inputs["decoder_batch_ids"] = None + self.share_inputs["decoder_tile_ids_per_batch"] = None + self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory + self.share_inputs["max_len_tensor_cpu"] = None # CPU # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) @@ -845,6 +846,8 @@ class GPUModelRunner(ModelRunnerBase): attn_backend=self.attn_backends[0], decoder_batch_ids=self.share_inputs["decoder_batch_ids"], decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], + 
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"], + max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"], seq_lens_this_time=self.share_inputs["seq_lens_this_time"], @@ -856,7 +859,6 @@ class GPUModelRunner(ModelRunnerBase): ) # Update Batch type for cuda graph - # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch only_decode_batch = True prefill_exists = None # mix ep in single node @@ -946,6 +948,18 @@ class GPUModelRunner(ModelRunnerBase): ) head_dim = self.model_config.head_dim + # Initialize AttentionBackend buffers + encoder_block_shape_q = 64 + decoder_block_shape_q = 16 + decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 + decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + ) + self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() + self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -953,6 +967,8 @@ class GPUModelRunner(ModelRunnerBase): kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim, + encoder_block_shape_q=encoder_block_shape_q, + decoder_block_shape_q=decoder_block_shape_q, ) self.attn_backends.append(attn_backend) @@ -1527,12 +1543,8 @@ class GPUModelRunner(ModelRunnerBase): Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. """ - # TODO(gongshaotian): Use more efficient implementation - if self.forward_meta.step_use_cudagraph: - num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum() - for i in range(1, num_empty_batch + 1): - self.forward_meta.decoder_batch_ids[-i] = 0 - self.forward_meta.decoder_tile_ids_per_batch[-i] = 0 + # In init_attention_metadata, the decode buffer has already been cleared + return def _init_image_preprocess(self) -> None: processor = DataProcessor( diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py index 1d0b4d208..a84ab7118 100644 --- a/fastdeploy/worker/iluvatar_model_runner.py +++ b/fastdeploy/worker/iluvatar_model_runner.py @@ -383,8 +383,8 @@ class IluvatarModelRunner(ModelRunnerBase): self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") # AttentionBackend buffers - self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") - self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["decoder_batch_ids"] = None + self.share_inputs["decoder_tile_ids_per_batch"] = None # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
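
Closing note (illustrative, values assumed): dropping the per-replay zeroing from the model runners' CUDA-graph cleanup hooks is safe because the decoder buffers are now sized once for the worst case and cleared with `cudaMemsetAsync` inside the op itself. The host-side `decode_max_tile_size` formula is an upper bound on the device-side per-batch tile count (`div_up(decoder_step_token_num * group_size, decoder_block_shape_q)` in the .cu change above); a small sketch of that relationship, with assumed head counts:

```python
# Sketch: why decode_max_tile_size bounds what split_q_block can write (all sizes assumed).
import numpy as np

max_num_seqs = 64                        # assumed runner capacity
num_heads, kv_num_heads = 32, 4          # assumed head layout
decoder_block_shape_q = 16
decoder_step_token_num = 1 + 3           # num_speculative_tokens + 1, assumed

# Host-side sizing (mirrors the gpu_model_runner change above).
decode_max_tile_size = int(
    max_num_seqs
    * np.ceil((decoder_step_token_num * np.ceil(num_heads / kv_num_heads)) / decoder_block_shape_q)
)

# Device-side per-batch bound used when the op memsets the buffers:
# decoder_max_tile_size_per_bs_q = div_up(decoder_step_token_num * group_size, decoder_block_shape_q)
group_size = num_heads // kv_num_heads
per_bs_tiles = -(-(decoder_step_token_num * group_size) // decoder_block_shape_q)  # ceiling division

# The preallocated buffer always covers the kernel's worst-case write range.
assert decode_max_tile_size >= max_num_seqs * per_bs_tiles
print(decode_max_tile_size, max_num_seqs * per_bs_tiles)
```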