Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-29 02:52:55 +08:00
[Graph Optimization][Speculative Decoding] Fix the bug of CUDAGraph + MTP + EP (#4430)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* Fix MTP dummy run bug
* Target Model and Draft Model use the same flag
* Avoid MoE bug in CUDAGraph padding
* In MTP, replace use_cudagraph with step_use_cudagraph
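The core of the change is that the target model's per-step CUDA graph decision is now passed down to the MTP draft proposer instead of being recomputed inside it. Below is a minimal, self-contained sketch of that flag-propagation pattern; the class and parameter names are simplified stand-ins (not the FastDeploy API), and the EP all_gather that actually produces the decode-only condition is replaced by a plain boolean.

# Illustrative sketch only: simplified stand-ins for the target-model runner and the
# MTP draft proposer, not the FastDeploy API. The target model computes the per-step
# CUDA graph flag once and hands the same flag to the draft model, so the two sides
# never disagree under expert parallelism.


class DraftProposer:
    """Stand-in for the MTP draft model: trusts the flag passed in by the caller."""

    def __init__(self, use_cudagraph: bool):
        self.use_cudagraph = use_cudagraph

    def run(self, hidden_states, step_use_cudagraph: bool = False):
        # Reuse the target model's decision instead of re-running a collective check.
        replay = step_use_cudagraph and self.use_cudagraph
        return {"hidden_states": hidden_states, "replay_cuda_graph": replay}


class TargetRunner:
    """Stand-in for the target-model runner: owns the per-step CUDA graph decision."""

    def __init__(self, use_cudagraph: bool, proposer: DraftProposer):
        self.use_cudagraph = use_cudagraph
        self.proposer = proposer

    def step(self, hidden_states, decode_only_on_all_ranks: bool):
        # In the real runner this condition comes from an all_gather over EP ranks;
        # here it is just a boolean argument.
        step_use_cudagraph = self.use_cudagraph and decode_only_on_all_ranks
        # Pass the same flag down so target and draft stay in lockstep.
        return self.proposer.run(hidden_states, step_use_cudagraph=step_use_cudagraph)


if __name__ == "__main__":
    runner = TargetRunner(use_cudagraph=True, proposer=DraftProposer(use_cudagraph=True))
    print(runner.step([0.0] * 8, decode_only_on_all_ranks=True))   # replay_cuda_graph: True
    print(runner.step([0.0] * 8, decode_only_on_all_ranks=False))  # replay_cuda_graph: False

Keeping the decision in one place means both models either enter the CUDA graph replay path together or not at all, which is what avoids the EP hanging problem noted in the updated `_propose` docstring.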
@@ -496,12 +496,12 @@ std::vector<paddle::Tensor> AppendAttention(
   paddle::Tensor fmha_out;
   if (out_linear_in_scale > 0.0) {
     if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
+      fmha_out = paddle::zeros(
           {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
           paddle::DataType::INT8,
           qkv.place());
     } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
+      fmha_out = paddle::zeros(
           {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
           paddle::DataType::FLOAT8_E4M3FN,
           qkv.place());
@@ -509,7 +509,7 @@ std::vector<paddle::Tensor> AppendAttention(
       PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
     }
   } else {
-    fmha_out = GetEmptyTensor(
+    fmha_out = paddle::zeros(
         {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
         dtype_id,
         qkv.place());
@@ -79,7 +79,7 @@ class MTPProposer(Proposer):
         self._init_model_inputs()

         # CUDA Graph
-        self.use_cudagraph = False  # self.graph_opt_config.use_cudagraph
+        self.use_cudagraph = False  # TODO(gongshaotian): Use Target Model flag
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

@@ -117,6 +117,9 @@ class MTPProposer(Proposer):
             self.parallel_config.max_model_len - max_dec_len,
         )

+        if self.fd_config.parallel_config.enable_expert_parallel:
+            input_length = min(input_length, 32)
+
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
@@ -541,7 +544,7 @@ class MTPProposer(Proposer):
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

-    def _initialize_forward_meta(self):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         """
         Initialize forward meta and attention meta data
         """
@@ -569,23 +572,8 @@ class MTPProposer(Proposer):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-        # Update Batch type for cuda graph
-        only_decode_batch = True
-        prefill_exists = None
-
-        # Mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
-            only_decode_batch_list = []
-            prefill_exists = self.exist_prefill()
-            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
-            only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
-
-        self.forward_meta.step_use_cudagraph = (
-            self.use_cudagraph
-            and only_decode_batch
-            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
-        )
+        # TODO(gongshaotian): Use CUDAGraph with Draft Model
+        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.use_cudagraph

     def exist_prefill(self):
         """
@@ -671,9 +659,12 @@ class MTPProposer(Proposer):
             self.parallel_config.use_ep,
         )

-    def _propose(self):
+    def _propose(self, step_use_cudagraph: bool = False):
         """
-        Main process for MTP inference
+        Main process for MTP inference.
+        Args:
+            step_use_cudagraph: bool
+                Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
         """
         for substep in range(self.num_model_steps):
             if self.model_inputs["not_need_stop"]:
@@ -697,7 +688,7 @@ class MTPProposer(Proposer):

                 # Initialize forward meta data
                 self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
-                self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
+                self.model_inputs["batch_id_per_token"][:] = -1
                 self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
                 self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
                 # for speculative decoding
@@ -705,7 +696,8 @@ class MTPProposer(Proposer):
                 self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)

                 # Initialize forward meta data
-                self._initialize_forward_meta()
+                self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+                self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)

                 # Padding inputs for cuda graph
                 self.padding_cudagraph_inputs()
@@ -733,7 +725,7 @@ class MTPProposer(Proposer):
                     previous_hidden_states=self.model_inputs["target_hidden_states"],
                     forward_meta=self.forward_meta,
                 )
-                if self.use_cudagraph:
+                if self.forward_meta.step_use_cudagraph:
                     model_output = model_output[: self.real_token_num]

                 hidden_states = rebuild_padding(
@@ -861,10 +853,10 @@ class MTPProposer(Proposer):
            self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
            self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()

-    def _run_impl(self, full_hidden_states):
-        """"""
+    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+        """Execute Draft Model"""
         self._prepare_inputs(full_hidden_states)
-        self._propose()
+        self._propose(step_use_cudagraph=step_use_cudagraph)
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
@@ -881,7 +873,7 @@ class MTPProposer(Proposer):
         # In init_attention_metadata, the decode buffer has already been cleared

         # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
-        if self.use_cudagraph:
+        if self.forward_meta.step_use_cudagraph:
             self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
             self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
         return
@@ -1253,7 +1253,9 @@ class GPUModelRunner(ModelRunnerBase):

         if self.speculative_decoding:
             if self.speculative_method == "mtp":
-                self.proposer.run(full_hidden_states=model_output)
+                self.proposer.run(
+                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)

@@ -1600,7 +1602,9 @@ class GPUModelRunner(ModelRunnerBase):
         # 6. Speculative decode
         if self.speculative_decoding:
             if self.speculative_method == "mtp":
-                self.proposer.run(full_hidden_states=model_output)
+                self.proposer.run(
+                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)
