mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
[MTP Fix] Fix code and register cpp operators (#2965)
This commit is contained in:
@@ -491,13 +491,6 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
for substep in range(self.max_draft_token_num):
|
||||
if self.model_inputs["not_need_stop"]:
|
||||
if substep != 0:
|
||||
target_hidden_states = eagle_get_self_hidden_states(
|
||||
hiddden_states,
|
||||
self.last_seq_lens_this_time,
|
||||
self.model_inputs["seq_lens_this_time"],
|
||||
self.model_inputs["step_idx"],
|
||||
)
|
||||
self.model_inputs["substep"] = substep
|
||||
# Remove padding
|
||||
(
|
||||
@@ -543,17 +536,15 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
|
||||
if self.max_draft_token_num > 1:
|
||||
self.last_seq_lens_this_time = paddle.clone(
|
||||
self.model_inputs["seq_lens_this_time"]
|
||||
)
|
||||
|
||||
self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
|
||||
|
||||
model_output = self.model(
|
||||
ids_remove_padding=self.model_inputs["ids_remove_padding"],
|
||||
previous_hidden_states=target_hidden_states,
|
||||
forward_meta=self.forward_meta,
|
||||
)
|
||||
|
||||
hiddden_states = rebuild_padding(
|
||||
hidden_states = rebuild_padding(
|
||||
model_output,
|
||||
self.model_inputs["cum_offsets"],
|
||||
self.model_inputs["seq_lens_this_time"],
|
||||
@@ -564,7 +555,7 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
|
||||
# 4. Compute logits, Sample
|
||||
logits = self.model.compute_logits(hiddden_states)
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
|
||||
sampled_token_ids = self.sampler(
|
||||
logits,
|
||||
@@ -578,6 +569,21 @@ class MTPProposer(Proposer):
|
||||
|
||||
self._post_process(sampled_token_ids)
|
||||
|
||||
if substep != self.max_draft_token_num - 1:
|
||||
target_hidden_states = self._get_self_hidden_states(hidden_states)
|
||||
|
||||
def _get_self_hidden_states(self, hidden_states):
    """
    Derive the hidden states passed on to the next draft substep.

    Forwards ``hidden_states`` to ``eagle_get_self_hidden_states`` together
    with ``self.last_seq_lens_this_time`` and the ``seq_lens_this_time`` and
    ``step_idx`` entries of ``self.model_inputs``.

    Args:
        hidden_states: Hidden states produced by the current substep.

    Returns:
        The operator's result; when the operator returns a list, only its
        first element is returned.
    """
    selected = eagle_get_self_hidden_states(
        hidden_states,
        self.last_seq_lens_this_time,
        self.model_inputs["seq_lens_this_time"],
        self.model_inputs["step_idx"],
    )
    # The operator may hand its output back wrapped in a list; unwrap it so
    # callers always receive the first (only used) output directly.
    return selected[0] if isinstance(selected, list) else selected
|
||||
|
||||
def update_task_chunk_prefill(self, task):
|
||||
"""
|
||||
Update single task's chunk_prefill info
|
||||
|
Reference in New Issue
Block a user