[MTP Fix] Fix code and register cpp operators (#2965)

This commit is contained in:
GoldPancake
2025-07-22 19:36:24 +08:00
committed by GitHub
parent 93bb68aa71
commit 9b84d51e25
2 changed files with 27 additions and 13 deletions

View File

@@ -491,13 +491,6 @@ class MTPProposer(Proposer):
"""
for substep in range(self.max_draft_token_num):
if self.model_inputs["not_need_stop"]:
if substep != 0:
target_hidden_states = eagle_get_self_hidden_states(
hiddden_states,
self.last_seq_lens_this_time,
self.model_inputs["seq_lens_this_time"],
self.model_inputs["step_idx"],
)
self.model_inputs["substep"] = substep
# Remove padding
(
@@ -543,17 +536,15 @@ class MTPProposer(Proposer):
)
if self.max_draft_token_num > 1:
self.last_seq_lens_this_time = paddle.clone(
self.model_inputs["seq_lens_this_time"]
)
self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
model_output = self.model(
ids_remove_padding=self.model_inputs["ids_remove_padding"],
previous_hidden_states=target_hidden_states,
forward_meta=self.forward_meta,
)
hiddden_states = rebuild_padding(
hidden_states = rebuild_padding(
model_output,
self.model_inputs["cum_offsets"],
self.model_inputs["seq_lens_this_time"],
@@ -564,7 +555,7 @@ class MTPProposer(Proposer):
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hiddden_states)
logits = self.model.compute_logits(hidden_states)
sampled_token_ids = self.sampler(
logits,
@@ -578,6 +569,21 @@ class MTPProposer(Proposer):
self._post_process(sampled_token_ids)
if substep != self.max_draft_token_num - 1:
target_hidden_states = self._get_self_hidden_states(hidden_states)
def _get_self_hidden_states(self, hidden_states):
target_hidden_states = eagle_get_self_hidden_states(
hidden_states,
self.last_seq_lens_this_time,
self.model_inputs["seq_lens_this_time"],
self.model_inputs["step_idx"],
)
if isinstance(target_hidden_states, list):
target_hidden_states = target_hidden_states[0]
return target_hidden_states
def update_task_chunk_prefill(self, task):
"""
Update single task's chunk_prefill info