[Feature][MTP] Support multi-step MTP (#2952)

Author: GoldPancake
Date: 2025-07-22 16:26:29 +08:00
Committed by: GitHub
Parent: 920e6b3f60
Commit: dc67c10a7e
4 changed files with 33 additions and 8 deletions
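
Summary (as read from the hunks below): when num_speculative_tokens > 1, the MTP proposer now runs several draft substeps per target-model step. From the second substep onward, the draft model conditions on its own hidden states from the previous substep, gathered via the new eagle_get_self_hidden_states op using sequence lengths snapshotted before each forward pass. Per-head acceptance-rate reporting is guarded against division by zero for heads that no request reaches, and --speculative_benchmark_mode is parsed as a string to sidestep the argparse type=bool pitfall.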


@@ -277,6 +277,8 @@ class SpeculativeConfig:
         for key, value in args.items():
             if key in name_map.keys() and hasattr(self, name_map[key]):
+                if key == "speculative_benchmark_mode":
+                    value = True if value.lower() == "true" else False
                 setattr(self, name_map[key], value)
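
The string parsing above works around a standard argparse pitfall: type=bool applies bool() to the raw argument string, and any non-empty string, including "False", is truthy. A minimal standalone sketch of the behavior (flag names are hypothetical):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--flag_bool", default=False, type=bool)   # broken: bool("False") is True
    parser.add_argument("--flag_str", default="False", type=str)   # what the patch switches to

    args = parser.parse_args(["--flag_bool", "False", "--flag_str", "False"])
    print(args.flag_bool)                   # True  (non-empty string is truthy)
    print(args.flag_str.lower() == "true")  # False (explicit string comparison)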


@@ -288,9 +288,12 @@ class TokenProcessor:
         if self.cfg.speculative_config.method in ["mtp"]:
             single_head_acceptance_rates = []
             for head in range(self.cfg.speculative_config.num_speculative_tokens):
+                if self.num_rest_requests_per_head[head] != 0:
                     single_head_acceptance_rates.append(
                         self.num_accept_requests_per_head[head] / self.num_rest_requests_per_head[head]
                     )
+                else:
+                    single_head_acceptance_rates.append(0)
             spec_logger.info(f" Single head accept ratio: {single_head_acceptance_rates}")
         if self.number_of_output_tokens > 1000000:
@@ -599,9 +602,12 @@ class TokenProcessor:
                 # Update the rest requests for each head
                 num_rest_requests = num_accept_requests
                 # Calculate the acceptance rate for each head
+                if self.num_rest_requests_per_head[head] != 0:
                     single_head_acceptance_rate = (
                         self.num_accept_requests_per_head[head] / self.num_rest_requests_per_head[head]
                     )
+                else:
+                    single_head_acceptance_rate = 0
                 main_process_metrics.spec_decode_draft_single_head_acceptance_rate[head].set(
                     single_head_acceptance_rate
                 )
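
Both TokenProcessor hunks apply the same division-by-zero guard: draft heads late in the chain may never be reached by any request, leaving their denominator at zero. A minimal standalone sketch of the guarded computation (head_acceptance_rates and its arguments are hypothetical stand-ins for the per-head counters):

    def head_acceptance_rates(num_accept, num_rest):
        """Per-head acceptance rate; 0 when no request reached the head."""
        rates = []
        for accept, rest in zip(num_accept, num_rest):
            rates.append(accept / rest if rest != 0 else 0)
        return rates

    # The third head is never reached, so its denominator is 0.
    print(head_acceptance_rates([8, 3, 0], [10, 5, 0]))  # [0.8, 0.6, 0]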


@@ -34,6 +34,7 @@ from fastdeploy.model_executor.ops.gpu import (
     draft_model_preprocess,
     draft_model_update,
     eagle_get_hidden_states,
+    eagle_get_self_hidden_states,
     mtp_save_first_token,
     mtp_step_paddle,
     share_external_data,
@@ -305,6 +306,10 @@ class MTPProposer(Proposer):
self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool") self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32") self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
if self.max_draft_token_num > 1:
self.last_seq_lens_this_time = paddle.full_like(
self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
)
def insert_prefill_inputs(self, req_dicts: List[Request]): def insert_prefill_inputs(self, req_dicts: List[Request]):
""" """
@@ -486,6 +491,13 @@ class MTPProposer(Proposer):
""" """
for substep in range(self.max_draft_token_num): for substep in range(self.max_draft_token_num):
if self.model_inputs["not_need_stop"]: if self.model_inputs["not_need_stop"]:
if substep != 0:
target_hidden_states = eagle_get_self_hidden_states(
hiddden_states,
self.last_seq_lens_this_time,
self.model_inputs["seq_lens_this_time"],
self.model_inputs["step_idx"],
)
self.model_inputs["substep"] = substep self.model_inputs["substep"] = substep
# Remove padding # Remove padding
( (
@@ -530,6 +542,11 @@ class MTPProposer(Proposer):
                     eos_token_ids=self.model_inputs["eos_token_id"],
                 )
+                if self.max_draft_token_num > 1:
+                    self.last_seq_lens_this_time = paddle.clone(
+                        self.model_inputs["seq_lens_this_time"]
+                    )
                 model_output = self.model(
                     ids_remove_padding=self.model_inputs["ids_remove_padding"],
                     previous_hidden_states=target_hidden_states,
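
Putting the MTPProposer hunks together, the multi-step draft loop reads roughly as sketched below. This is a simplified interpretation of the diff, not the actual implementation: run_draft_steps is a hypothetical wrapper, the hidden_states plumbing outside the hunks is assumed, and dispatch/sampling steps between the snapshot and the forward pass are omitted.

    import paddle
    from fastdeploy.model_executor.ops.gpu import eagle_get_self_hidden_states

    def run_draft_steps(model, model_inputs, target_hidden_states,
                        max_draft_token_num, last_seq_lens_this_time):
        """Hedged sketch of the multi-step MTP draft loop implied by the diff."""
        for substep in range(max_draft_token_num):
            if not model_inputs["not_need_stop"]:
                break
            if substep != 0:
                # From the second substep on, the draft model conditions on its
                # own hidden states from the previous substep, aligned via the
                # sequence lengths snapshotted before that forward pass.
                target_hidden_states = eagle_get_self_hidden_states(
                    hidden_states,  # draft model output of the prior substep (assumed name)
                    last_seq_lens_this_time,
                    model_inputs["seq_lens_this_time"],
                    model_inputs["step_idx"],
                )
            model_inputs["substep"] = substep
            if max_draft_token_num > 1:
                # Snapshot seq_lens before the forward pass so the next substep
                # can gather the draft hidden states at the right positions.
                last_seq_lens_this_time = paddle.clone(model_inputs["seq_lens_this_time"])
            hidden_states = model(
                ids_remove_padding=model_inputs["ids_remove_padding"],
                previous_hidden_states=target_hidden_states,
            )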


@@ -499,8 +499,8 @@ def parse_args():
     )
     parser.add_argument(
         "--speculative_benchmark_mode",
-        default=False,
-        type=bool,
+        default="False",
+        type=str,
     )
     parser.add_argument(
         "--max_num_batched_tokens",