Mirror of https://github.com/PaddlePaddle/FastDeploy.git
update hybrid-mtp-with-ngram (#3924)
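This updates the hybrid MTP + ngram path in MTPProposer: it allocates a host-side input_ids_cpu buffer alongside the device tensors, records each request's usable prompt length in input_ids_len (length - 1, since the draft input is the prompt shifted left by one), mirrors the shifted prompt into the CPU buffer on both the disaggregated decode path and the default prefill path, and has hybrid_mtp_ngram read the host buffer instead of copying the device input_ids on every step.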
@@ -268,7 +268,11 @@ class MTPProposer(Proposer):
         self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
         self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
         self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
-
+        self.model_inputs["input_ids_cpu"] = paddle.full(
+            shape=[self.max_num_seqs, self.parallel_config.max_model_len],
+            fill_value=-1,
+            dtype="int64",
+        ).cpu()
         self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
         self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
         self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
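The buffer is built with paddle.full(...).cpu() so the ngram matcher can scan prompt tokens in host memory instead of pulling input_ids back from the device on each proposal step. A minimal sketch of the pattern, with illustrative sizes rather than the FastDeploy configuration objects:

    import numpy as np
    import paddle

    max_num_seqs, max_model_len = 4, 64  # illustrative; FastDeploy sizes these from its configs

    # Device-side ids feed the draft model; the host-side mirror feeds the
    # CPU ngram matcher. -1 marks unused slots in both copies.
    input_ids = paddle.full([max_num_seqs, max_model_len], -1, dtype="int64")
    input_ids_cpu = paddle.full([max_num_seqs, max_model_len], -1, dtype="int64").cpu()

    # Staging a hypothetical request's prompt into both copies:
    prompt = np.array([101, 7, 42, 9, 2], dtype=np.int64)
    input_ids[0:1, : len(prompt)] = paddle.to_tensor(prompt)
    input_ids_cpu[0:1, : len(prompt)] = prompt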
@@ -368,10 +372,17 @@ class MTPProposer(Proposer):
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
-            self.input_ids_len[idx] = length
+            self.input_ids_len[idx] = length - 1
+
             if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
+                length = len(request.prompt_token_ids)
+                if length > 1:
+                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
+                        "input_ids"
+                    ][idx : idx + 1, 1:length]
+                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
+                        request.prompt_token_ids
+                    )[1:]
                 self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
                 prefill_token_num = self.max_draft_token_num + 1
                 self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
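The length - 1 bookkeeping and the [1:length] slices implement the usual MTP alignment: the draft head predicts the token after the one the target model just consumed, so the draft prompt is the target prompt shifted left by one, leaving length - 1 valid positions. In plain numpy, with hypothetical token ids:

    import numpy as np

    prompt = np.array([101, 7, 42, 9, 2], dtype=np.int64)
    length = len(prompt)

    draft_input = np.full([length], -1, dtype=np.int64)
    draft_input[: length - 1] = prompt[1:]  # same slice the diff writes: [1:length]
    draft_valid_len = length - 1            # what input_ids_len records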
@@ -400,6 +411,10 @@ class MTPProposer(Proposer):
                 self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][
                     idx : idx + 1, 1:length
                 ]
+                self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
+                    request.prompt_token_ids
+                )[1:]
+
             self.model_inputs["pre_ids"][idx : idx + 1] = -1
             self.model_inputs["step_idx"][idx : idx + 1] = 0
             if self.cache_config.enable_chunked_prefill:
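The default prefill branch mirrors the same left-shifted prompt into input_ids_cpu, so the host copy stays consistent with the device ids regardless of which path admitted the request; pre_ids and step_idx are then reset (-1 and 0) for the fresh request.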
@@ -688,7 +703,7 @@ class MTPProposer(Proposer):
         seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
         seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
         hybrid_mtp_ngram(
-            self.model_inputs["input_ids"]._copy_to(device, True),
+            self.model_inputs["input_ids_cpu"],
             self.input_ids_len,
             self.model_inputs["pre_ids"]._copy_to(device, True),
             self.model_inputs["step_idx"].cpu(),
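With input_ids_cpu passed in, hybrid_mtp_ngram can run its prompt-lookup matching directly on the host copy. The op itself is a FastDeploy custom kernel; the idea behind ngram drafting, sketched in plain Python under assumed semantics (the function name and parameters below are illustrative, not the kernel's signature):

    import numpy as np

    def ngram_propose(tokens, valid_len, max_ngram=3, max_draft=4):
        # Prompt lookup: find the most recent match of the trailing n-gram
        # earlier in the sequence and propose the tokens that followed it.
        seq = tokens[:valid_len]
        for n in range(max_ngram, 0, -1):
            if valid_len <= n:
                continue
            tail = seq[-n:]
            for start in range(valid_len - n - 1, -1, -1):
                if np.array_equal(seq[start : start + n], tail):
                    follow = seq[start + n : start + n + max_draft]
                    if follow.size:
                        return follow
        return np.empty([0], dtype=seq.dtype)

    ids = np.array([101, 7, 42, 9, 2, 55, 7, 42], dtype=np.int64)
    print(ngram_propose(ids, len(ids)))  # the bigram (7, 42) recurs -> [ 9  2 55  7]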