Mirror of https://github.com/PaddlePaddle/FastDeploy.git
【Inference Optimize】DeepSeek-V3-model MLA Optimize (#3886)
* support MLA chunk_size auto search & cuda_graph
@@ -838,6 +838,8 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["decoder_batch_ids"] = None
         self.share_inputs["decoder_tile_ids_per_batch"] = None
         self.share_inputs["decoder_num_blocks_cpu"] = None  # Pinning Memory
+        self.share_inputs["decoder_num_blocks_device"] = None
+        self.share_inputs["decoder_chunk_size_device"] = None
         self.share_inputs["max_len_tensor_cpu"] = None  # CPU
         self.share_inputs["encoder_batch_ids"] = None
         self.share_inputs["encoder_tile_ids_per_batch"] = None
@@ -991,6 +993,8 @@ class GPUModelRunner(ModelRunnerBase):
         )

         self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
+        # NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions
+        self.share_inputs["batch_id_per_token"][:] = -1
         self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
         self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
         self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
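
The share_inputs buffers have fixed shapes so they can be reused across steps (and captured by CUDA graphs), while a given step may only fill a prefix of them; pre-filling batch_id_per_token with -1 invalidates whatever a previous, larger batch left in the tail. A minimal sketch of that pattern, using illustrative shapes and names rather than the runner's actual buffers (the runner itself uses the non-blocking copy_(…, False) calls shown above; plain slicing is used here for brevity):

    import paddle

    max_tokens = 16  # illustrative fixed buffer size
    batch_id_per_token = paddle.full([max_tokens], -1, dtype="int32")  # persistent buffer

    # New step: only 5 tokens are valid this time.
    new_ids = paddle.to_tensor([0, 0, 1, 1, 1], dtype="int32")

    batch_id_per_token[:] = -1                        # invalidate stale entries first
    batch_id_per_token[: new_ids.shape[0]] = new_ids  # write the valid prefix
    # Positions >= 5 stay at -1, so downstream kernels can treat them as illegal.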
@@ -1070,6 +1074,10 @@ class GPUModelRunner(ModelRunnerBase):
            decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
            decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
            decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
+           # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
+           # adapted to cudagraph.
+           decoder_num_blocks_device=self.share_inputs["decoder_num_blocks_device"],
+           decoder_chunk_size_device=self.share_inputs["decoder_chunk_size_device"],
            max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
            seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
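
The NOTE above is about CUDA graph compatibility: a captured graph replays fixed kernel launches and cannot consult host-side values at replay time, so any quantity the MLA kernel needs per step has to live in a GPU tensor that is updated in place before each replay. A hedged sketch of that pattern; the helper name, the update step, and the replay comment are illustrative, not the runner's actual code:

    import paddle

    # One-element GPU tensors captured by the graph; kernels read them at replay time.
    decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
    decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")

    def prepare_step(num_blocks_this_step: int, chunk_size: int):
        # Update the captured tensors in place; the graph itself is never re-captured.
        decoder_num_blocks_device.copy_(paddle.full([1], num_blocks_this_step, dtype="int32"), False)
        decoder_chunk_size_device.copy_(paddle.full([1], chunk_size, dtype="int32"), False)

    prepare_step(num_blocks_this_step=12, chunk_size=128)
    # Replaying the captured graph would then launch the MLA kernels, which read the updated values.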
@@ -1196,8 +1204,12 @@ class GPUModelRunner(ModelRunnerBase):
         decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
         group_size = np.ceil(num_heads / self.model_config.kv_num_heads)

-        decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
-            (decoder_step_token_num * group_size) / decoder_block_shape_q
-        )
+        # NOTE: (changwenbin) When using auto_chunk,
+        # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K.
+        decode_max_tile_size = (
+            1024
+            * self.parallel_config.max_num_seqs
+            * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
+        )
         encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
             (self.model_config.max_model_len * group_size) / encoder_block_shape_q
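
To make the sizing concrete, here is the same arithmetic with illustrative numbers: 128 query heads, 1 KV head, decoder_block_shape_q = 16, max_num_seqs = 64, no speculative tokens. These values are assumptions chosen for the example, not configuration read from DeepSeek-V3:

    import numpy as np

    num_heads = 128                  # assumed for the example
    kv_num_heads = 1                 # assumed for the example
    decoder_block_shape_q = 16
    max_num_seqs = 64
    decoder_step_token_num = 0 + 1   # num_speculative_tokens + 1

    group_size = np.ceil(num_heads / kv_num_heads)                                           # 128.0
    tiles_per_seq = np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)   # 8.0

    # Old sizing: one decode step's worth of tiles.
    old_decode_max_tile_size = max_num_seqs * tiles_per_seq                                  # 512.0

    # New sizing: *1024 headroom so auto-searched chunk counts (per the NOTE, enough to
    # cover 128K context) still fit in the preallocated buffers.
    decode_max_tile_size = 1024 * max_num_seqs * tiles_per_seq                               # 524288.0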
@@ -1208,6 +1220,10 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
         self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
         self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
+        # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
+        # adapted to cudagraph.
+        self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
+        self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
         self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()

         self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
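
The allocation keeps a pinned-host copy of the block count next to the new device copy: pinned (page-locked) host memory is what allows an asynchronous host-to-device transfer, so a value computed on the CPU during scheduling can be pushed into the graph-visible device tensor without a blocking sync. A small sketch of that pairing, assuming a CUDA default device and using only tensor methods that appear in the diff; the update step and the value 37 are illustrative:

    import paddle

    # Host side: pinned so the copy below can be asynchronous.
    decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
    # Device side: the tensor the captured MLA kernel actually reads.
    decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
    # Initialized to 64; presumably overwritten by the chunk-size auto search each step.
    decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")

    # Scheduler computed 37 KV-cache blocks for this step (illustrative value).
    decoder_num_blocks_cpu[:] = 37
    decoder_num_blocks_device.copy_(decoder_num_blocks_cpu, False)  # non-blocking H2D copy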