support chunk_prefill in MTP (#2705)

Author: freeliuzc
Date: 2025-07-04 11:55:48 +08:00
Committed by: GitHub
Parent: b38823bc66
Commit: 667547be59
3 changed files with 64 additions and 6 deletions


@@ -405,17 +405,21 @@ class MTPProposer(Proposer):
                                                            1:length]
             self.model_inputs["pre_ids"][idx:idx + 1] = -1
             self.model_inputs["step_idx"][idx:idx + 1] = 0
-            # TODO(liuzichang) finish chunked_prefill
             if self.parallel_config.enable_chunked_prefill:
-                raise NotImplementedError(
-                    "MTP don't support chunked_prefill now")
+                token_chunk_size = request.prefill_chunk_info[0]
+                self.model_inputs["seq_lens_encoder"][idx:idx + 1] = token_chunk_size
+                self.model_inputs["seq_lens_this_time"][idx:idx + 1] = token_chunk_size
             else:
                 self.model_inputs["seq_lens_encoder"][idx:idx + 1] = length
                 self.model_inputs["seq_lens_decoder"][idx:idx + 1] = (
                     request.get("seq_lens_decoder", 0))
                 self.model_inputs["seq_lens_this_time"][idx:idx + 1] = length
             self.model_inputs["seq_lens_decoder"][idx:idx + 1] = (
                 request.get("seq_lens_decoder", 0))
             self.model_inputs["stop_flags"][idx:idx + 1] = False
             self.model_inputs["batch_drop"][idx:idx + 1] = False
@@ -578,7 +582,6 @@ class MTPProposer(Proposer):
             self.model_inputs["output_padding_offset"],
             self.parallel_config.max_model_len,
         )
-        paddle.device.synchronize()
 
         # 4. Compute logits, Sample
         logits = self.model.compute_logits(hiddden_states)
@@ -595,6 +598,43 @@ class MTPProposer(Proposer):
 
         self._post_process(sampled_token_ids)
 
+    def update_task_chunk_prefill(self, task):
+        """
+        Update single task's chunk_prefill info
+        """
+        idx = task.idx
+        start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
+        if task.chunk_idx == len(task.prefill_chunk_info):
+            self.model_inputs['seq_lens_encoder'][idx:idx + 1] = 0
+            self.model_inputs["step_idx"][idx:idx + 1] = 1
+            self.model_inputs["seq_lens_decoder"][
+                idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+        else:
+            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
+            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
+                self.model_inputs['input_ids'][idx, :token_chunk_size] = np.array(
+                    task.prompt_token_ids[start_idx + 1:start_idx + token_chunk_size + 1])
+            # Last prefill
+            else:
+                self.model_inputs['input_ids'][idx, :token_chunk_size - 1] = np.array(
+                    task.prompt_token_ids[start_idx + 1:start_idx + token_chunk_size])
+            self.model_inputs["seq_lens_this_time"][idx:idx + 1] = token_chunk_size
+            self.model_inputs['seq_lens_encoder'][idx:idx + 1] = token_chunk_size
+            self.model_inputs["step_idx"][idx:idx + 1] = 0
+            self.model_inputs["seq_lens_decoder"][
+                idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
     def _update_status(self):
         """
         Update main-model's forward info in next step.
 
@@ -624,6 +664,11 @@ class MTPProposer(Proposer):
         )
 
     def _run_impl(self, full_hidden_states):
         """"""
         target_hidden_states = self._prepare_inputs(full_hidden_states)
         self._propose(target_hidden_states=target_hidden_states)
         self._update_status()
+
+    def is_chunk_prefill_enabled(self):
+        """"""
+        return True
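One plausible way a worker loop could drive these two hooks, sketched with hypothetical names (advance_chunked_prefill and the task fields are assumptions; the real call sites live in files not shown in this diff):

def advance_chunked_prefill(proposer, tasks):
    """Refresh the draft model's chunk bookkeeping for tasks still in chunked prefill."""
    if not proposer.is_chunk_prefill_enabled():
        return
    for task in tasks:
        # Assumed convention: the scheduler bumps task.chunk_idx after each chunk,
        # so values up to len(prefill_chunk_info) mean prefill is ongoing or just finished.
        if task.chunk_idx <= len(task.prefill_chunk_info):
            proposer.update_task_chunk_prefill(task)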