Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Feature][MTP] Support a new speculative decoding method: hybrid MTP with ngram (#3610)
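
For orientation, the change keeps the existing proposer pipeline and only adds an optional n-gram extension step at the end of each run (see the _run_impl hunk at the end of the diff). Below is a condensed, self-contained sketch of that order of operations; the strings in the trace are the real method names from the diff, everything else is illustrative:

# Sketch of the per-iteration flow of MTPProposer._run_impl after this change.
def run_impl_sketch(hybrid_mode: bool) -> list:
    trace = []
    trace.append("_prepare_inputs")    # build draft-model inputs from the main model state
    trace.append("_propose")           # run the MTP head for num_model_steps substeps
    trace.append("_update_status")     # sync proposer state with the main model
    if hybrid_mode:                    # new in this PR
        trace.append("_extend_draft_token_with_ngram_match")
    return trace

print(run_impl_sketch(hybrid_mode=True))
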
@@ -35,6 +35,7 @@ from fastdeploy.model_executor.ops.gpu import (
     draft_model_update,
     eagle_get_hidden_states,
     eagle_get_self_hidden_states,
+    hybrid_mtp_ngram,
     mtp_save_first_token,
     mtp_step_paddle,
     share_external_data,
@@ -57,6 +58,8 @@ class MTPProposer(Proposer):
         self._update_cfg(main_model)
         self._load_model()
         self.main_model_inputs = main_model_inputs
+        self.mtp_strategy = self.speculative_config.mtp_strategy
+        self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps

         # [mixed, prefill, decoder]
         self.role = "mixed"
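
The two new attributes above gate the feature. A minimal illustration of when hybrid_mode turns on; the numbers are made-up examples, not defaults shipped by the PR:

def is_hybrid_mode(mtp_strategy: str, max_draft_token_num: int, num_model_steps: int) -> bool:
    # Mirrors the flag set in MTPProposer.__init__: n-gram extension only makes sense
    # when the overall draft budget is larger than the number of MTP model steps.
    return mtp_strategy == "with_ngram" and max_draft_token_num > num_model_steps

print(is_hybrid_mode("with_ngram", max_draft_token_num=5, num_model_steps=3))  # True
print(is_hybrid_mode("with_ngram", max_draft_token_num=3, num_model_steps=3))  # False
print(is_hybrid_mode("other_strategy", max_draft_token_num=5, num_model_steps=3))  # False
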
@@ -336,10 +339,11 @@ class MTPProposer(Proposer):

         self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
         self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
-        if self.max_draft_token_num > 1:
+        if self.num_model_steps > 1:
             self.last_seq_lens_this_time = paddle.full_like(
                 self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
             )
+        self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()

     def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
         """
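
The new self.input_ids_len buffer is a host-side record of each request's prompt length, which the n-gram matcher later consumes together with the CPU copies of input_ids and pre_ids. A minimal sketch of that bookkeeping, assuming only a Paddle install; the slot index and prompt are invented examples:

import paddle

max_num_seqs = 4
# one int64 length per batch slot, kept on CPU like self.input_ids_len
input_ids_len = paddle.zeros(shape=[max_num_seqs, 1], dtype="int64").cpu()

prompt_token_ids = [101, 7, 9, 42]          # hypothetical request prompt
idx = 2                                     # hypothetical batch slot
input_ids_len[idx] = len(prompt_token_ids)  # what insert_prefill_inputs now records

print(input_ids_len.numpy().tolist())       # [[0], [0], [4], [0]]
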
@@ -364,6 +368,7 @@ class MTPProposer(Proposer):
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
+            self.input_ids_len[idx] = length

             if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
                 length = len(request.prompt_token_ids)
@@ -460,6 +465,7 @@ class MTPProposer(Proposer):
             self.model_inputs["step_idx"],
             self.model_inputs["not_need_stop"],
             self.model_inputs["batch_drop"],
+            self.model_inputs["pre_ids"],
             self.main_model_inputs["accept_tokens"],
             self.main_model_inputs["accept_num"],
             self.main_model_inputs["seq_lens_this_time"],
@@ -469,7 +475,7 @@ class MTPProposer(Proposer):
             self.main_model_inputs["stop_flags"],
             self.main_model_inputs["is_block_step"],
             self.main_model_inputs["draft_tokens"],
-            self.max_draft_token_num,
+            self.num_model_steps,
             self.speculative_method in ["eagle", "mtp"],
             self.role == "prefill",
         )
@@ -483,7 +489,7 @@ class MTPProposer(Proposer):
             self.main_model_inputs["accept_num"],
             self.main_model_inputs["seq_lens_this_time"],
             self.main_model_inputs["seq_lens_encoder"],
-            self.max_draft_token_num,
+            self.num_model_steps,
         )
         if isinstance(target_hidden_states, list):
             target_hidden_states = target_hidden_states[0]
@@ -523,7 +529,7 @@ class MTPProposer(Proposer):
         """
         Main process for MTP inference
         """
-        for substep in range(self.max_draft_token_num):
+        for substep in range(self.num_model_steps):
             if self.model_inputs["not_need_stop"]:
                 self.model_inputs["substep"] = substep
                 # Remove padding
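
The proposer loop now iterates over num_model_steps (the number of draft-model forward passes) rather than max_draft_token_num (the total draft budget, which can be larger once n-gram extension is enabled). A self-contained sketch of that distinction; dummy_draft_step is a stand-in for the real draft-model forward and sampling, not a FastDeploy API:

def dummy_draft_step(prev_token: int) -> int:
    return prev_token + 1                     # placeholder "prediction"

def propose(last_accepted_token: int, num_model_steps: int, max_draft_token_num: int):
    drafts = []
    token = last_accepted_token
    for substep in range(num_model_steps):    # was range(max_draft_token_num) before this PR
        token = dummy_draft_step(token)
        drafts.append(token)
    leftover = max_draft_token_num - num_model_steps   # slots the n-gram matcher may fill
    return drafts, leftover

print(propose(100, num_model_steps=3, max_draft_token_num=5))   # ([101, 102, 103], 2)
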
@@ -542,6 +548,7 @@ class MTPProposer(Proposer):
                     self.model_inputs["seq_lens_encoder"],
                     self.model_inputs["seq_lens_decoder"],
                 )

                 # Initialize forward meta data
                 self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
                 self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
@@ -567,7 +574,7 @@ class MTPProposer(Proposer):
                     eos_token_ids=self.model_inputs["eos_token_id"],
                 )

-                if self.max_draft_token_num > 1:
+                if self.num_model_steps > 1:
                     self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])

                 model_output = self.model(
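
last_seq_lens_this_time is snapshotted with paddle.clone before the draft-model call, presumably because seq_lens_this_time is updated in place by the subsequent step; paddle.clone returns an independent copy rather than a view. A small stand-alone check of that behavior, with arbitrary values:

import paddle

seq_lens_this_time = paddle.to_tensor([3, 1, 0], dtype="int32")
snapshot = paddle.clone(seq_lens_this_time)                          # independent copy
seq_lens_this_time[:] = paddle.to_tensor([1, 1, 1], dtype="int32")   # in-place update

print(snapshot.numpy().tolist())              # still [3, 1, 0]
print(seq_lens_this_time.numpy().tolist())    # [1, 1, 1]
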
@@ -601,7 +608,7 @@ class MTPProposer(Proposer):

                 self._post_process(sampled_token_ids)

-                if substep != self.max_draft_token_num - 1:
+                if substep != self.num_model_steps - 1:
                     target_hidden_states = self._get_self_hidden_states(hidden_states)

     def _get_self_hidden_states(self, hidden_states):
@@ -673,11 +680,37 @@ class MTPProposer(Proposer):
             self.max_draft_token_num,
         )

+    def _extend_draft_token_with_ngram_match(self):
+        # TODO(liuzichang): Optimize this kernel into a CUDA kernel to reduce latency
+        device = paddle.CUDAPinnedPlace()
+
+        draft_tokens = self.main_model_inputs["draft_tokens"].cpu()
+        seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
+        seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
+        hybrid_mtp_ngram(
+            self.model_inputs["input_ids"]._copy_to(device, True),
+            self.input_ids_len,
+            self.model_inputs["pre_ids"]._copy_to(device, True),
+            self.model_inputs["step_idx"].cpu(),
+            self.main_model_inputs["actual_draft_token_num"].cpu(),
+            draft_tokens,
+            seq_lens_this_time,
+            seq_lens_decoder,
+            self.model_inputs["max_dec_len"].cpu(),
+            self.max_ngram_size,
+            self.min_ngram_size,
+            self.max_draft_token_num,
+        )
+        self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
+        self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
+
     def _run_impl(self, full_hidden_states):
         """"""
         target_hidden_states = self._prepare_inputs(full_hidden_states)
         self._propose(target_hidden_states=target_hidden_states)
         self._update_status()
+        if self.hybrid_mode:
+            self._extend_draft_token_with_ngram_match()

     def is_chunk_prefill_enabled(self):
         """"""
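
The hybrid_mtp_ngram op itself is not shown in this diff. The sketch below is a pure-Python approximation of the usual prompt-lookup n-gram idea that its inputs suggest: match the most recent tokens against earlier context and reuse whatever followed the match to fill the remaining draft slots. Treat the function signature and the matching policy as assumptions, not the op's actual contract:

# Assumed behavior only: n-gram extension of the draft tokens produced by the MTP head.
def ngram_extend(context: list, drafts: list, budget: int,
                 max_ngram_size: int, min_ngram_size: int) -> list:
    history = context + drafts                 # tokens known so far, in order
    room = budget - len(drafts)                # free draft slots after the MTP head
    if room <= 0:
        return drafts
    for n in range(max_ngram_size, min_ngram_size - 1, -1):
        if len(history) < n:
            continue
        suffix = history[-n:]
        # look for an earlier occurrence of the suffix (excluding the suffix itself)
        for start in range(len(history) - n - 1, -1, -1):
            if history[start:start + n] == suffix:
                follow = history[start + n:start + n + room]
                return drafts + follow         # reuse the tokens that followed the match
    return drafts

context = [5, 8, 3, 9, 5, 8, 3]                # "... 5 8 3" appears twice
print(ngram_extend(context, drafts=[9], budget=4,
                   max_ngram_size=3, min_ngram_size=1))   # [9, 5, 8, 3]

Trying n from max_ngram_size down to min_ngram_size prefers longer, more specific matches; if nothing matches, the drafts from the MTP head are returned unchanged.
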