mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
[Feature][MTP]update multi-draft-token strategy (#3369)
* update multi-draft-token strategy * fix format --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -317,7 +317,9 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["max_len_tensor_cpu"] = None # CPU
|
||||
|
||||
# Input tokens
|
||||
self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
|
||||
self.model_inputs["draft_tokens"] = paddle.full(
|
||||
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
|
||||
)
|
||||
|
||||
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
|
||||
|
||||
@@ -461,6 +463,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["batch_drop"],
|
||||
self.main_model_inputs["accept_tokens"],
|
||||
self.main_model_inputs["accept_num"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
self.main_model_inputs["seq_lens_encoder"],
|
||||
self.main_model_inputs["seq_lens_decoder"],
|
||||
self.main_model_inputs["step_idx"],
|
||||
|
Reference in New Issue
Block a user