mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature][MTP]Support multi-step MTP (#2952)
This commit is contained in:
@@ -34,6 +34,7 @@ from fastdeploy.model_executor.ops.gpu import (
|
||||
draft_model_preprocess,
|
||||
draft_model_update,
|
||||
eagle_get_hidden_states,
|
||||
eagle_get_self_hidden_states,
|
||||
mtp_save_first_token,
|
||||
mtp_step_paddle,
|
||||
share_external_data,
|
||||
@@ -305,6 +306,10 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
|
||||
self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
|
||||
if self.max_draft_token_num > 1:
|
||||
self.last_seq_lens_this_time = paddle.full_like(
|
||||
self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
|
||||
)
|
||||
|
||||
def insert_prefill_inputs(self, req_dicts: List[Request]):
|
||||
"""
|
||||
@@ -486,6 +491,13 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
for substep in range(self.max_draft_token_num):
|
||||
if self.model_inputs["not_need_stop"]:
|
||||
if substep != 0:
|
||||
target_hidden_states = eagle_get_self_hidden_states(
|
||||
hiddden_states,
|
||||
self.last_seq_lens_this_time,
|
||||
self.model_inputs["seq_lens_this_time"],
|
||||
self.model_inputs["step_idx"],
|
||||
)
|
||||
self.model_inputs["substep"] = substep
|
||||
# Remove padding
|
||||
(
|
||||
@@ -530,6 +542,11 @@ class MTPProposer(Proposer):
|
||||
eos_token_ids=self.model_inputs["eos_token_id"],
|
||||
)
|
||||
|
||||
if self.max_draft_token_num > 1:
|
||||
self.last_seq_lens_this_time = paddle.clone(
|
||||
self.model_inputs["seq_lens_this_time"]
|
||||
)
|
||||
|
||||
model_output = self.model(
|
||||
ids_remove_padding=self.model_inputs["ids_remove_padding"],
|
||||
previous_hidden_states=target_hidden_states,
|
||||
|
Reference in New Issue
Block a user