[MTP]update hybrid-mtp-with-ngram (#4047)

This commit is contained in:
freeliuzc
2025-09-15 17:13:31 +08:00
committed by GitHub
parent b1b33211e8
commit 46911f903d
3 changed files with 41 additions and 3 deletions

View File

@@ -295,6 +295,11 @@ class MTPProposer(Proposer):
# Same shape/dytpe with base model
self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
self.model_inputs["input_ids_cpu"] = paddle.full(
shape=[self.max_num_seqs, self.parallel_config.max_model_len],
fill_value=-1,
dtype="int64",
).cpu()
self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
@@ -401,11 +406,14 @@ class MTPProposer(Proposer):
input_ids = request.prompt_token_ids + request.output_token_ids
self.input_ids_len[idx] = length
self.input_ids_len[idx] = length - 1
self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
idx : idx + 1, 1:length
]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length].cpu()
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
@@ -468,10 +476,17 @@ class MTPProposer(Proposer):
request = req_dicts[i]
idx = request.idx
length = len(request.prompt_token_ids)
self.input_ids_len[idx] = length
self.input_ids_len[idx] = length - 1
if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
length = len(request.prompt_token_ids)
if length > 1:
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids
)[1:]
self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
prefill_token_num = self.max_draft_token_num + 1
self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
@@ -500,6 +515,9 @@ class MTPProposer(Proposer):
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids
)[1:]
self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["step_idx"][idx : idx + 1] = 0
if self.cache_config.enable_chunked_prefill:
@@ -800,7 +818,7 @@ class MTPProposer(Proposer):
seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
hybrid_mtp_ngram(
self.model_inputs["input_ids"]._copy_to(device, True),
self.model_inputs["input_ids_cpu"],
self.input_ids_len,
self.model_inputs["pre_ids"]._copy_to(device, True),
self.model_inputs["step_idx"].cpu(),