[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)
* reset decoder_block_shape_q buffer
* refactor GetBlockShapeAndSplitKVBlock Kernel and cudagraph padding batch
* update decode_max_tile_size
* fix pre-commit
* update block_multihead_attn_backend
* update flash attn backend
* update MLA Attention
* update XPU Attention
* update gcu, iluvatar model runner
* Update MTP
* fix MTP bug
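For context on the sizes this commit touches: with grouped-query attention, each decoding sequence contributes `decoder_step_token_num * ceil(num_heads / kv_num_heads)` query rows per step, and the kernel splits those rows into tiles of `decoder_block_shape_q`. A minimal sketch of that per-sequence tile count, using the quantities named in the diff below (the helper is illustrative, not the actual CUDA kernel):

```python
import math

def decode_tiles_per_seq(step_tokens: int, num_heads: int,
                         kv_num_heads: int, block_shape_q: int) -> int:
    """How many query-row tiles one decoding sequence occupies per step."""
    group_size = math.ceil(num_heads / kv_num_heads)  # GQA group size
    return math.ceil(step_tokens * group_size / block_shape_q)

# e.g. 2 step tokens (1 speculative + 1), 32 query heads, 8 KV heads,
# decoder_block_shape_q = 16  ->  ceil(2 * 4 / 16) = 1 tile per sequence
```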
@@ -610,9 +610,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
-        self.share_inputs["not_need_stop"] = paddle.full(
-            [1], False, dtype="bool"
-        ).cpu()  # TODO(gongshaotian): move to pinned memory
+        self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu()
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
@@ -643,9 +641,12 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-        # AttentionBackend buffers
-        self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+
+        # Declare AttentionBackend buffers
+        self.share_inputs["decoder_batch_ids"] = None
+        self.share_inputs["decoder_tile_ids_per_batch"] = None
+        self.share_inputs["decoder_num_blocks_cpu"] = None  # Pinned memory
+        self.share_inputs["max_len_tensor_cpu"] = None  # CPU

         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
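The decoder buffers switch from eager allocation at `share_inputs` setup time to lazy allocation: they are declared as `None` here and materialized in the `@@ -946,6 +948,18 @@` hunk below, once the tile shapes are known. A minimal sketch of the pattern, with a hypothetical holder class standing in for `GPUModelRunner`:

```python
import paddle

class LazyBufferHolder:
    """Hypothetical stand-in for GPUModelRunner's share_inputs handling."""

    def __init__(self):
        # Declared up front so every consumer can look up the key,
        # but not allocated until the required size is known.
        self.share_inputs = {"decoder_batch_ids": None}

    def initialize_attn_backend(self, decode_max_tile_size: float):
        # Allocated once, then reused (and zeroed) across decode steps.
        self.share_inputs["decoder_batch_ids"] = paddle.full(
            [int(decode_max_tile_size)], 0, dtype="int32"
        )
```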
@@ -845,6 +846,8 @@ class GPUModelRunner(ModelRunnerBase):
             attn_backend=self.attn_backends[0],
             decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
             decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
+            decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
+            max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
             seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
             seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
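This call site (presumably where the forward metadata passed to the attention backends is built) now threads the two host-side tensors through as well. A hedged sketch of the corresponding fields; the names come from the diff, but the dataclass layout and comments are assumptions:

```python
from dataclasses import dataclass
from typing import Optional
import paddle

@dataclass
class ForwardMetaSketch:
    # Device-side tile mappings, sized by decode_max_tile_size.
    decoder_batch_ids: Optional[paddle.Tensor] = None
    decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
    # Host-side outputs of GetBlockShapeAndSplitKVBlock:
    # a pinned-memory block count and an 8-slot CPU max-length tensor.
    decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
    max_len_tensor_cpu: Optional[paddle.Tensor] = None
```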
@@ -856,7 +859,6 @@ class GPUModelRunner(ModelRunnerBase):
         )

         # Update Batch type for cuda graph
-        # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
         only_decode_batch = True
         prefill_exists = None
         # mix ep in single node
@@ -946,6 +948,18 @@ class GPUModelRunner(ModelRunnerBase):
         )
         head_dim = self.model_config.head_dim

+        # Initialize AttentionBackend buffers
+        encoder_block_shape_q = 64
+        decoder_block_shape_q = 16
+        decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
+        decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
+            (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
+        )
+        self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
+        self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
+
         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
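`decode_max_tile_size` bounds how many decode tiles the kernel can ever emit, so the two device buffers can be allocated once and safely reused under CUDA graph capture. A worked instance of the formula above with illustrative numbers (none of these values are FastDeploy defaults):

```python
import numpy as np

# Illustrative configuration, not actual defaults.
max_num_seqs = 128           # parallel_config.max_num_seqs
num_heads = 32               # query heads on this rank
kv_num_heads = 8             # -> GQA group size of 4
num_speculative_tokens = 1   # one speculative token per step
decoder_block_shape_q = 16

decoder_step_token_num = num_speculative_tokens + 1  # 2
group_size = np.ceil(num_heads / kv_num_heads)       # 4.0
decode_max_tile_size = max_num_seqs * np.ceil(
    (decoder_step_token_num * group_size) / decoder_block_shape_q
)  # 128 * ceil(8 / 16) = 128.0

print(int(decode_max_tile_size))  # 128 int32 slots per decoder buffer
```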
@@ -953,6 +967,8 @@ class GPUModelRunner(ModelRunnerBase):
             kv_num_heads=self.model_config.kv_num_heads,
             num_heads=num_heads,
             head_dim=head_dim,
+            encoder_block_shape_q=encoder_block_shape_q,
+            decoder_block_shape_q=decoder_block_shape_q,
         )

         self.attn_backends.append(attn_backend)
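With `encoder_block_shape_q` and `decoder_block_shape_q` passed in explicitly, the backend no longer needs to hard-code its tile shapes. On the encoder (prefill) side the same tiling logic applies per prompt; a hypothetical example:

```python
import math

encoder_block_shape_q = 64   # value set in the hunk above
prompt_lens = [100, 37]      # hypothetical prefill lengths

# Each prompt is split into ceil(len / block_shape_q) query tiles.
encoder_tiles = sum(math.ceil(n / encoder_block_shape_q) for n in prompt_lens)
print(encoder_tiles)  # ceil(100/64) + ceil(37/64) = 2 + 1 = 3
```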
@@ -1527,12 +1543,8 @@ class GPUModelRunner(ModelRunnerBase):
         Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
         In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
         """
-        # TODO(gongshaotian): Use more efficient implementation
-        if self.forward_meta.step_use_cudagraph:
-            num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
-            for i in range(1, num_empty_batch + 1):
-                self.forward_meta.decoder_batch_ids[-i] = 0
-                self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
+        # In init_attention_metadata, the decode buffer has already been cleared
+        return

     def _init_image_preprocess(self) -> None:
         processor = DataProcessor(
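The tail-clearing loop can be dropped because, after the refactor, the decode buffers are already zeroed before `GetBlockShapeAndSplitKVBlock` repopulates them, so padded batch slots stay zero without per-element writes. A sketch of the invariant the early `return` relies on (the method name comes from the diff's comment; its body here is an assumption):

```python
def init_attention_metadata(forward_meta):
    # Zero the whole tile-mapping buffers before the kernel fills in
    # entries for live batches; slots belonging to padded (empty)
    # batches then remain zero, which CUDA graph replay requires.
    forward_meta.decoder_batch_ids[:] = 0
    forward_meta.decoder_tile_ids_per_batch[:] = 0
```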