Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)

[Feature][MTP] Support multi-step MTP (#2952)
@@ -277,6 +277,8 @@ class SpeculativeConfig:
         for key, value in args.items():
             if key in name_map.keys() and hasattr(self, name_map[key]):
+                if key == "speculative_benchmark_mode":
+                    value = True if value.lower() == "true" else False
                 setattr(self, name_map[key], value)

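The coercion above is needed because the value now arrives from the command line as a string (see the parse_args() change at the bottom of this commit), so it cannot be used as a Python bool directly. A minimal standalone sketch of the same idea; parse_bool_flag is an illustrative helper, not a FastDeploy function:

def parse_bool_flag(value) -> bool:
    # Booleans pass through; strings are compared case-insensitively,
    # mirroring the SpeculativeConfig change above.
    if isinstance(value, bool):
        return value
    return value.lower() == "true"

assert parse_bool_flag("True") is True
assert parse_bool_flag("false") is False
assert parse_bool_flag(True) is True
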
@@ -288,9 +288,12 @@ class TokenProcessor:
         if self.cfg.speculative_config.method in ["mtp"]:
             single_head_acceptance_rates = []
             for head in range(self.cfg.speculative_config.num_speculative_tokens):
+                if self.num_rest_requests_per_head[head] != 0:
                     single_head_acceptance_rates.append(
                         self.num_accept_requests_per_head[head] / self.num_rest_requests_per_head[head]
                     )
+                else:
+                    single_head_acceptance_rates.append(0)
             spec_logger.info(f" Single head accept ratio: {single_head_acceptance_rates}")

         if self.number_of_output_tokens > 1000000:

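The new branch guards against heads that have not drafted any tokens yet: early in a run num_rest_requests_per_head[head] can be zero, and the previously unguarded division would raise ZeroDivisionError. A self-contained sketch of the same computation with made-up counter values:

# Hypothetical counters: accepted vs. drafted tokens per speculative head.
num_accept_requests_per_head = [90, 60, 0]
num_rest_requests_per_head = [100, 80, 0]  # third head has no samples yet

single_head_acceptance_rates = []
for head in range(len(num_rest_requests_per_head)):
    if num_rest_requests_per_head[head] != 0:
        single_head_acceptance_rates.append(
            num_accept_requests_per_head[head] / num_rest_requests_per_head[head]
        )
    else:
        single_head_acceptance_rates.append(0)  # report 0 instead of crashing

print(single_head_acceptance_rates)  # [0.9, 0.75, 0]
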
@@ -599,9 +602,12 @@ class TokenProcessor:
             # Update the rest requests for each head
             num_rest_requests = num_accept_requests
             # Calculate the acceptance rate for each head
+            if self.num_rest_requests_per_head[head] != 0:
                 single_head_acceptance_rate = (
                     self.num_accept_requests_per_head[head] / self.num_rest_requests_per_head[head]
                 )
+            else:
+                single_head_acceptance_rate = 0
             main_process_metrics.spec_decode_draft_single_head_acceptance_rate[head].set(
                 single_head_acceptance_rate
             )

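main_process_metrics and its per-head gauge list are defined elsewhere in FastDeploy and are not part of this diff. For orientation, a sketch of how a per-head acceptance-rate gauge could be exposed with the standard prometheus_client package; the wiring and naming here are assumptions, not FastDeploy's actual metrics module:

from prometheus_client import Gauge

# One gauge labelled by head index, rather than a Python list of gauges.
single_head_acceptance_rate_gauge = Gauge(
    "spec_decode_draft_single_head_acceptance_rate",
    "Acceptance rate of each speculative draft head",
    ["head"],
)

def report_head_rate(head: int, rate: float) -> None:
    single_head_acceptance_rate_gauge.labels(head=str(head)).set(rate)

report_head_rate(0, 0.9)
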
@@ -34,6 +34,7 @@ from fastdeploy.model_executor.ops.gpu import (
     draft_model_preprocess,
     draft_model_update,
     eagle_get_hidden_states,
+    eagle_get_self_hidden_states,
     mtp_save_first_token,
     mtp_step_paddle,
     share_external_data,

@@ -305,6 +306,10 @@ class MTPProposer(Proposer):
         self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
         self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
+        if self.max_draft_token_num > 1:
+            self.last_seq_lens_this_time = paddle.full_like(
+                self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
+            )

     def insert_prefill_inputs(self, req_dicts: List[Request]):
         """

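last_seq_lens_this_time is only allocated when more than one draft token is produced per step, and -1 acts as a "nothing recorded yet" sentinel. A small demonstration of the paddle.full_like allocation pattern, with a toy tensor standing in for the real seq-lens buffer:

import paddle

seq_lens_this_time = paddle.zeros([4], dtype="int32")  # toy stand-in

# Same shape and dtype as the source tensor, filled with the -1 sentinel.
last_seq_lens_this_time = paddle.full_like(seq_lens_this_time, fill_value=-1, dtype="int32")
print(last_seq_lens_this_time.numpy())  # [-1 -1 -1 -1]
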
@@ -486,6 +491,13 @@ class MTPProposer(Proposer):
         """
         for substep in range(self.max_draft_token_num):
             if self.model_inputs["not_need_stop"]:
+                if substep != 0:
+                    target_hidden_states = eagle_get_self_hidden_states(
+                        hiddden_states,
+                        self.last_seq_lens_this_time,
+                        self.model_inputs["seq_lens_this_time"],
+                        self.model_inputs["step_idx"],
+                    )
                 self.model_inputs["substep"] = substep
                 # Remove padding
                 (

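eagle_get_self_hidden_states is a custom GPU op, so its internals are not shown here; what the diff adds is control flow. On substep 0 the draft model consumes hidden states produced by the target (main) model; on every later substep it re-gathers its own hidden states from the previous substep, using the saved sequence lengths to locate each request's last position. A schematic of that loop in plain Python, with toy stand-ins for the model and the gather op (none of these names are FastDeploy APIs):

class ToyDraftModel:
    # Stand-in for the MTP draft model; not FastDeploy code.
    def __init__(self):
        self.seq_lens_this_time = [3, 1]

    def forward(self, hidden_states):
        # A real substep runs the draft transformer and updates seq lens;
        # here we just shrink the lengths to show state changing.
        self.seq_lens_this_time = [1] * len(self.seq_lens_this_time)
        return [h + 1 for h in hidden_states]

def gather_self_hidden_states(hidden_states, last_seq_lens, seq_lens_now):
    # Stand-in for the eagle_get_self_hidden_states GPU op: select each
    # request's last-token hidden state from the previous substep.
    return hidden_states

def run_draft_steps(model, target_hidden_states, max_draft_token_num):
    hidden_states = target_hidden_states
    last_seq_lens = None
    for substep in range(max_draft_token_num):
        if substep != 0:
            hidden_states = gather_self_hidden_states(
                hidden_states, last_seq_lens, model.seq_lens_this_time
            )
        # Snapshot seq lens before the forward pass mutates them,
        # mirroring the paddle.clone() added in the next hunk.
        last_seq_lens = list(model.seq_lens_this_time)
        hidden_states = model.forward(hidden_states)
    return hidden_states

print(run_draft_steps(ToyDraftModel(), [0.0, 0.0], max_draft_token_num=3))
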
@@ -530,6 +542,11 @@ class MTPProposer(Proposer):
                     eos_token_ids=self.model_inputs["eos_token_id"],
                 )

+                if self.max_draft_token_num > 1:
+                    self.last_seq_lens_this_time = paddle.clone(
+                        self.model_inputs["seq_lens_this_time"]
+                    )
+
                 model_output = self.model(
                     ids_remove_padding=self.model_inputs["ids_remove_padding"],
                     previous_hidden_states=target_hidden_states,

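paddle.clone is used for the snapshot rather than plain assignment because seq_lens_this_time is updated in place by later ops; a bare assignment would alias the tensor, and the "last" lengths would silently track the current ones. A minimal demonstration:

import paddle

seq_lens = paddle.to_tensor([5, 3], dtype="int32")
alias = seq_lens                     # shares storage with seq_lens
snapshot = paddle.clone(seq_lens)    # independent copy

seq_lens[0] = 1  # in-place update, as the draft substep does
print(alias.numpy())     # [1 3] -- the alias changed too
print(snapshot.numpy())  # [5 3] -- the clone kept the old values
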
@@ -499,8 +499,8 @@ def parse_args():
     )
     parser.add_argument(
         "--speculative_benchmark_mode",
-        default=False,
-        type=bool,
+        default="False",
+        type=str,
     )
     parser.add_argument(
         "--max_num_batched_tokens",

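The type change fixes a standard argparse pitfall: type=bool calls bool() on the raw string, and any non-empty string, including "False", is truthy, so --speculative_benchmark_mode False would still enable the mode. Passing the value through as a string and comparing it case-insensitively (the SpeculativeConfig change at the top of this commit) gives the intended behavior. A minimal demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--broken", default=False, type=bool)
parser.add_argument("--fixed", default="False", type=str)

args = parser.parse_args(["--broken", "False", "--fixed", "False"])
print(args.broken)                   # True  -- bool("False") is truthy
print(args.fixed.lower() == "true")  # False -- string comparison works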