diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index e1d48f41c..c20455e06 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -681,6 +681,12 @@ std::vector<paddle::Tensor> EagleGetHiddenStates( const paddle::Tensor& base_model_seq_lens_encoder, const int actual_draft_token_num); +std::vector<paddle::Tensor> EagleGetSelfHiddenStates( + const paddle::Tensor& input, + const paddle::Tensor& last_seq_lens_this_time, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& step_idx); + void MTPStepPaddle( const paddle::Tensor &base_model_stop_flags, const paddle::Tensor &stop_flags, @@ -1063,6 +1069,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function"); + m.def("eagle_get_self_hidden_states", &EagleGetSelfHiddenStates, "eagle_get_self_hidden_states function"); + m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function"); m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function"); diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index d421b6a54..3acf7714d 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -491,13 +491,6 @@ class MTPProposer(Proposer): """ for substep in range(self.max_draft_token_num): if self.model_inputs["not_need_stop"]: - if substep != 0: - target_hidden_states = eagle_get_self_hidden_states( - hiddden_states, - self.last_seq_lens_this_time, - self.model_inputs["seq_lens_this_time"], - self.model_inputs["step_idx"], - ) self.model_inputs["substep"] = substep # Remove padding ( @@ -543,17 +536,15 @@ class MTPProposer(Proposer): ) if self.max_draft_token_num > 1: - self.last_seq_lens_this_time = paddle.clone( - self.model_inputs["seq_lens_this_time"] - ) - + self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"]) + model_output = self.model( 
ids_remove_padding=self.model_inputs["ids_remove_padding"], previous_hidden_states=target_hidden_states, forward_meta=self.forward_meta, ) - hiddden_states = rebuild_padding( + hidden_states = rebuild_padding( model_output, self.model_inputs["cum_offsets"], self.model_inputs["seq_lens_this_time"], @@ -564,7 +555,7 @@ class MTPProposer(Proposer): ) # 4. Compute logits, Sample - logits = self.model.compute_logits(hiddden_states) + logits = self.model.compute_logits(hidden_states) sampled_token_ids = self.sampler( logits, @@ -578,6 +569,21 @@ class MTPProposer(Proposer): self._post_process(sampled_token_ids) + if substep != self.max_draft_token_num - 1: + target_hidden_states = self._get_self_hidden_states(hidden_states) + + def _get_self_hidden_states(self, hidden_states): + target_hidden_states = eagle_get_self_hidden_states( + hidden_states, + self.last_seq_lens_this_time, + self.model_inputs["seq_lens_this_time"], + self.model_inputs["step_idx"], + ) + if isinstance(target_hidden_states, list): + target_hidden_states = target_hidden_states[0] + + return target_hidden_states + def update_task_chunk_prefill(self, task): """ Update single task's chunk_prefill info