[Feature][MTP]update multi-draft-token strategy (#3369)

* update multi-draft-token strategy

* fix format

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
freeliuzc
2025-08-18 13:59:56 +08:00
committed by GitHub
parent 3ee6053e5d
commit a12d0bc549
4 changed files with 69 additions and 30 deletions

View File

@@ -317,7 +317,9 @@ class MTPProposer(Proposer):
self.model_inputs["max_len_tensor_cpu"] = None # CPU
# Input tokens
self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
self.model_inputs["draft_tokens"] = paddle.full(
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
)
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
@@ -461,6 +463,7 @@ class MTPProposer(Proposer):
self.model_inputs["batch_drop"],
self.main_model_inputs["accept_tokens"],
self.main_model_inputs["accept_num"],
self.main_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"],
self.main_model_inputs["seq_lens_decoder"],
self.main_model_inputs["step_idx"],