[MTP Fix] Fix code and register cpp operators (#2965)

This commit is contained in:
GoldPancake
2025-07-22 19:36:24 +08:00
committed by GitHub
parent 93bb68aa71
commit 9b84d51e25
2 changed files with 27 additions and 13 deletions

View File

@@ -681,6 +681,12 @@ std::vector<paddle::Tensor> EagleGetHiddenStates(
const paddle::Tensor& base_model_seq_lens_encoder, const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num); const int actual_draft_token_num);
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx);
void MTPStepPaddle( void MTPStepPaddle(
const paddle::Tensor &base_model_stop_flags, const paddle::Tensor &base_model_stop_flags,
const paddle::Tensor &stop_flags, const paddle::Tensor &stop_flags,
@@ -1063,6 +1069,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function"); m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function");
m.def("eagle_get_self_hidden_states", &EagleGetSelfHiddenStates, "eagle_get_self_hidden_states function");
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function"); m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function"); m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");

View File

@@ -491,13 +491,6 @@ class MTPProposer(Proposer):
""" """
for substep in range(self.max_draft_token_num): for substep in range(self.max_draft_token_num):
if self.model_inputs["not_need_stop"]: if self.model_inputs["not_need_stop"]:
if substep != 0:
target_hidden_states = eagle_get_self_hidden_states(
hiddden_states,
self.last_seq_lens_this_time,
self.model_inputs["seq_lens_this_time"],
self.model_inputs["step_idx"],
)
self.model_inputs["substep"] = substep self.model_inputs["substep"] = substep
# Remove padding # Remove padding
( (
@@ -543,9 +536,7 @@ class MTPProposer(Proposer):
) )
if self.max_draft_token_num > 1: if self.max_draft_token_num > 1:
self.last_seq_lens_this_time = paddle.clone( self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
self.model_inputs["seq_lens_this_time"]
)
model_output = self.model( model_output = self.model(
ids_remove_padding=self.model_inputs["ids_remove_padding"], ids_remove_padding=self.model_inputs["ids_remove_padding"],
@@ -553,7 +544,7 @@ class MTPProposer(Proposer):
forward_meta=self.forward_meta, forward_meta=self.forward_meta,
) )
hiddden_states = rebuild_padding( hidden_states = rebuild_padding(
model_output, model_output,
self.model_inputs["cum_offsets"], self.model_inputs["cum_offsets"],
self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_this_time"],
@@ -564,7 +555,7 @@ class MTPProposer(Proposer):
) )
# 4. Compute logits, Sample # 4. Compute logits, Sample
logits = self.model.compute_logits(hiddden_states) logits = self.model.compute_logits(hidden_states)
sampled_token_ids = self.sampler( sampled_token_ids = self.sampler(
logits, logits,
@@ -578,6 +569,21 @@ class MTPProposer(Proposer):
self._post_process(sampled_token_ids) self._post_process(sampled_token_ids)
if substep != self.max_draft_token_num - 1:
target_hidden_states = self._get_self_hidden_states(hidden_states)
def _get_self_hidden_states(self, hidden_states):
    """Select the hidden states to feed into the next draft substep.

    Invokes the `eagle_get_self_hidden_states` custom op with the model's
    current sequence-length bookkeeping (`last_seq_lens_this_time`,
    `seq_lens_this_time`, `step_idx`) so the draft model consumes its own
    hidden states on subsequent MTP substeps.

    Args:
        hidden_states: hidden states produced by the previous draft forward
            pass (after padding removal).

    Returns:
        The gathered hidden states tensor. The custom op may return a
        single-element list; in that case the sole tensor is unwrapped.
    """
    gathered = eagle_get_self_hidden_states(
        hidden_states,
        self.last_seq_lens_this_time,
        self.model_inputs["seq_lens_this_time"],
        self.model_inputs["step_idx"],
    )
    # Custom paddle ops registered via pybind can hand back a list of
    # tensors; normalize to a bare tensor for downstream callers.
    return gathered[0] if isinstance(gathered, list) else gathered
def update_task_chunk_prefill(self, task): def update_task_chunk_prefill(self, task):
""" """
Update single task's chunk_prefill info Update single task's chunk_prefill info