Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[MTP]update hybrid-mtp-with-ngram (#4047)
@@ -295,6 +295,11 @@ class MTPProposer(Proposer):
         # Same shape/dytpe with base model
         self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
         self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
+        self.model_inputs["input_ids_cpu"] = paddle.full(
+            shape=[self.max_num_seqs, self.parallel_config.max_model_len],
+            fill_value=-1,
+            dtype="int64",
+        ).cpu()
         self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])

         self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
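
The new buffer gives the proposer a host-resident copy of the token ids, pre-filled with -1 so unused slots are easy to recognize. As a rough standalone sketch of the allocation pattern (the sizes below are made up, not FastDeploy's defaults):

import paddle

max_num_seqs, max_model_len = 4, 16  # illustrative sizes only
# build the tensor, then move it to host memory; -1 marks slots with no real token id
input_ids_cpu = paddle.full(
    shape=[max_num_seqs, max_model_len],
    fill_value=-1,
    dtype="int64",
).cpu()
print(input_ids_cpu.place)  # Place(cpu)
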
@@ -401,11 +406,14 @@ class MTPProposer(Proposer):

             input_ids = request.prompt_token_ids + request.output_token_ids

-            self.input_ids_len[idx] = length
+            self.input_ids_len[idx] = length - 1
             self.model_inputs["pre_ids"][idx : idx + 1] = -1
             self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
                 idx : idx + 1, 1:length
             ]
+            self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
+                "input_ids"
+            ][idx : idx + 1, 1:length].cpu()
             encoder_block_num = len(request.block_tables)
             self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
             self.model_inputs["block_tables"][idx : idx + 1, :] = -1
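
In these per-request rows the copy starts at index 1 of the target row, so the draft rows are the target rows shifted left by one token and only length - 1 ids are kept (which is why input_ids_len is now length - 1). A tiny illustration of that slicing with made-up token ids (plain numpy, not the paddle tensors used above):

import numpy as np

target_row = np.array([101, 7, 8, 9, 102, -1, -1, -1])  # hypothetical padded row
length = 5                                               # real tokens in the row
draft_row = np.full_like(target_row, -1)
draft_row[: length - 1] = target_row[1:length]           # drop the first token
print(draft_row)  # [  7   8   9 102  -1  -1  -1  -1]
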
@@ -468,10 +476,17 @@ class MTPProposer(Proposer):
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
-            self.input_ids_len[idx] = length
+            self.input_ids_len[idx] = length - 1

             if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
                 length = len(request.prompt_token_ids)
+                if length > 1:
+                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
+                        "input_ids"
+                    ][idx : idx + 1, 1:length]
+                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
+                        request.prompt_token_ids
+                    )[1:]
                 self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
                 prefill_token_num = self.max_draft_token_num + 1
                 self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
@@ -500,6 +515,9 @@ class MTPProposer(Proposer):
                 self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                     "input_ids"
                 ][idx : idx + 1, 1:length]
+                self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
+                    request.prompt_token_ids
+                )[1:]
                 self.model_inputs["pre_ids"][idx : idx + 1] = -1
                 self.model_inputs["step_idx"][idx : idx + 1] = 0
                 if self.cache_config.enable_chunked_prefill:
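
In both this hunk and the previous one the host row is filled straight from the request's Python token list via np.array(...)[1:] rather than from the device tensor. A minimal sketch of that assignment with hypothetical ids, reusing the kind of CPU tensor allocated in the first hunk:

import numpy as np
import paddle

input_ids_cpu = paddle.full(shape=[1, 8], fill_value=-1, dtype="int64").cpu()
prompt_token_ids = [101, 7, 8, 9, 102]              # hypothetical request
length = len(prompt_token_ids)
# again drop the first id: the draft row begins at the second prompt token
input_ids_cpu[0:1, : length - 1] = np.array(prompt_token_ids)[1:]
print(input_ids_cpu.numpy()[0])                     # [  7   8   9 102  -1  -1  -1  -1]
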
@@ -800,7 +818,7 @@ class MTPProposer(Proposer):
         seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
         seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
         hybrid_mtp_ngram(
-            self.model_inputs["input_ids"]._copy_to(device, True),
+            self.model_inputs["input_ids_cpu"],
             self.input_ids_len,
             self.model_inputs["pre_ids"]._copy_to(device, True),
             self.model_inputs["step_idx"].cpu(),
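
With this change the matcher reads token ids from the persistent input_ids_cpu buffer (together with input_ids_len) instead of taking a fresh _copy_to of input_ids on every call. hybrid_mtp_ngram itself is a custom op whose implementation is not part of this diff; purely as intuition, n-gram drafting looks for an earlier occurrence of the most recent tokens and proposes whatever followed that occurrence. A toy, self-contained illustration, not the op's implementation:

import numpy as np

def ngram_propose(tokens: np.ndarray, ngram: int = 3, max_draft: int = 4) -> np.ndarray:
    """Toy n-gram draft: find an earlier occurrence of the last `ngram` tokens
    and return up to `max_draft` tokens that followed it."""
    tail = tokens[-ngram:]
    for start in range(len(tokens) - ngram - 1, -1, -1):
        if np.array_equal(tokens[start : start + ngram], tail):
            return tokens[start + ngram : start + ngram + max_draft]
    return np.empty(0, dtype=tokens.dtype)

history = np.array([5, 6, 7, 8, 9, 3, 4, 5, 6, 7])  # hypothetical token ids
print(ngram_propose(history))  # [8 9 3 4] -> proposed continuation after the match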