mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[MTP Fix] Fix code and register cpp operators (#2965)
This commit is contained in:
@@ -681,6 +681,12 @@ std::vector<paddle::Tensor> EagleGetHiddenStates(
|
|||||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||||
const int actual_draft_token_num);
|
const int actual_draft_token_num);
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
|
||||||
|
const paddle::Tensor& input,
|
||||||
|
const paddle::Tensor& last_seq_lens_this_time,
|
||||||
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
|
const paddle::Tensor& step_idx);
|
||||||
|
|
||||||
void MTPStepPaddle(
|
void MTPStepPaddle(
|
||||||
const paddle::Tensor &base_model_stop_flags,
|
const paddle::Tensor &base_model_stop_flags,
|
||||||
const paddle::Tensor &stop_flags,
|
const paddle::Tensor &stop_flags,
|
||||||
@@ -1063,6 +1069,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
|
|
||||||
m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function");
|
m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function");
|
||||||
|
|
||||||
|
m.def("eagle_get_self_hidden_states", &EagleGetSelfHiddenStates, "eagle_get_self_hidden_states function");
|
||||||
|
|
||||||
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
|
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
|
||||||
|
|
||||||
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
|
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
|
||||||
|
@@ -491,13 +491,6 @@ class MTPProposer(Proposer):
|
|||||||
"""
|
"""
|
||||||
for substep in range(self.max_draft_token_num):
|
for substep in range(self.max_draft_token_num):
|
||||||
if self.model_inputs["not_need_stop"]:
|
if self.model_inputs["not_need_stop"]:
|
||||||
if substep != 0:
|
|
||||||
target_hidden_states = eagle_get_self_hidden_states(
|
|
||||||
hiddden_states,
|
|
||||||
self.last_seq_lens_this_time,
|
|
||||||
self.model_inputs["seq_lens_this_time"],
|
|
||||||
self.model_inputs["step_idx"],
|
|
||||||
)
|
|
||||||
self.model_inputs["substep"] = substep
|
self.model_inputs["substep"] = substep
|
||||||
# Remove padding
|
# Remove padding
|
||||||
(
|
(
|
||||||
@@ -543,9 +536,7 @@ class MTPProposer(Proposer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.max_draft_token_num > 1:
|
if self.max_draft_token_num > 1:
|
||||||
self.last_seq_lens_this_time = paddle.clone(
|
self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
|
||||||
self.model_inputs["seq_lens_this_time"]
|
|
||||||
)
|
|
||||||
|
|
||||||
model_output = self.model(
|
model_output = self.model(
|
||||||
ids_remove_padding=self.model_inputs["ids_remove_padding"],
|
ids_remove_padding=self.model_inputs["ids_remove_padding"],
|
||||||
@@ -553,7 +544,7 @@ class MTPProposer(Proposer):
|
|||||||
forward_meta=self.forward_meta,
|
forward_meta=self.forward_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hiddden_states = rebuild_padding(
|
hidden_states = rebuild_padding(
|
||||||
model_output,
|
model_output,
|
||||||
self.model_inputs["cum_offsets"],
|
self.model_inputs["cum_offsets"],
|
||||||
self.model_inputs["seq_lens_this_time"],
|
self.model_inputs["seq_lens_this_time"],
|
||||||
@@ -564,7 +555,7 @@ class MTPProposer(Proposer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 4. Compute logits, Sample
|
# 4. Compute logits, Sample
|
||||||
logits = self.model.compute_logits(hiddden_states)
|
logits = self.model.compute_logits(hidden_states)
|
||||||
|
|
||||||
sampled_token_ids = self.sampler(
|
sampled_token_ids = self.sampler(
|
||||||
logits,
|
logits,
|
||||||
@@ -578,6 +569,21 @@ class MTPProposer(Proposer):
|
|||||||
|
|
||||||
self._post_process(sampled_token_ids)
|
self._post_process(sampled_token_ids)
|
||||||
|
|
||||||
|
if substep != self.max_draft_token_num - 1:
|
||||||
|
target_hidden_states = self._get_self_hidden_states(hidden_states)
|
||||||
|
|
||||||
|
def _get_self_hidden_states(self, hidden_states):
|
||||||
|
target_hidden_states = eagle_get_self_hidden_states(
|
||||||
|
hidden_states,
|
||||||
|
self.last_seq_lens_this_time,
|
||||||
|
self.model_inputs["seq_lens_this_time"],
|
||||||
|
self.model_inputs["step_idx"],
|
||||||
|
)
|
||||||
|
if isinstance(target_hidden_states, list):
|
||||||
|
target_hidden_states = target_hidden_states[0]
|
||||||
|
|
||||||
|
return target_hidden_states
|
||||||
|
|
||||||
def update_task_chunk_prefill(self, task):
|
def update_task_chunk_prefill(self, task):
|
||||||
"""
|
"""
|
||||||
Update single task's chunk_prefill info
|
Update single task's chunk_prefill info
|
||||||
|
Reference in New Issue
Block a user