support chunk_prefill in MTP (#2705)

freeliuzc
2025-07-04 11:55:48 +08:00
committed by GitHub
parent b38823bc66
commit 667547be59
3 changed files with 64 additions and 6 deletions

View File

@@ -61,3 +61,13 @@ class Proposer(ABC):
         Implemention for different method
         """
         raise NotImplementedError
+
+    def is_chunk_prefill_enabled(self) -> bool:
+        """
+        Check whether chunk-based prefill is enabled.
+        Default is False.
+
+        Returns:
+            bool: True if chunk prefill is enabled; False otherwise.
+        """
+        return False
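For orientation: the base class now exposes a capability flag that concrete proposers can override, so callers can guard all chunk-prefill bookkeeping behind a single check. A minimal sketch of that pattern, assuming a hypothetical ChunkedProposer class (not part of this commit):

    from abc import ABC


    class Proposer(ABC):
        """Trimmed-down stand-in for the real base class."""

        def is_chunk_prefill_enabled(self) -> bool:
            # Default: chunk-based prefill is not supported.
            return False


    class ChunkedProposer(Proposer):
        """Hypothetical proposer that opts in, as MTPProposer does below."""

        def is_chunk_prefill_enabled(self) -> bool:
            return True


    if __name__ == "__main__":
        for proposer in (Proposer(), ChunkedProposer()):
            # Callers only touch chunk-prefill state when the proposer opts in.
            print(type(proposer).__name__, proposer.is_chunk_prefill_enabled())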

View File

@@ -405,17 +405,21 @@ class MTPProposer(Proposer):
                 1:length]
             self.model_inputs["pre_ids"][idx:idx + 1] = -1
             self.model_inputs["step_idx"][idx:idx + 1] = 0
-            # TODO(liuzichang) finish chunked_prefill
             if self.parallel_config.enable_chunked_prefill:
-                raise NotImplementedError(
-                    "MTP don't support chunked_prefill now")
+                token_chunk_size = request.prefill_chunk_info[0]
+                self.model_inputs["seq_lens_encoder"][idx:idx +
+                                                      1] = token_chunk_size
+                self.model_inputs["seq_lens_this_time"][
+                    idx:idx + 1] = token_chunk_size
             else:
                 self.model_inputs["seq_lens_encoder"][idx:idx + 1] = length
+                self.model_inputs["seq_lens_decoder"][idx:idx + 1] = (
+                    request.get("seq_lens_decoder", 0))
                 self.model_inputs["seq_lens_this_time"][idx:idx +
                                                         1] = length
-                self.model_inputs["seq_lens_decoder"][idx:idx +
-                                                      1] = (request.get(
-                                                          "seq_lens_decoder",
-                                                          0))
             self.model_inputs["stop_flags"][idx:idx + 1] = False
             self.model_inputs["batch_drop"][idx:idx + 1] = False
@@ -578,7 +582,6 @@ class MTPProposer(Proposer):
self.model_inputs["output_padding_offset"], self.model_inputs["output_padding_offset"],
self.parallel_config.max_model_len, self.parallel_config.max_model_len,
) )
paddle.device.synchronize()
# 4. Compute logits, Sample # 4. Compute logits, Sample
logits = self.model.compute_logits(hiddden_states) logits = self.model.compute_logits(hiddden_states)
@@ -595,6 +598,43 @@ class MTPProposer(Proposer):
         self._post_process(sampled_token_ids)

+    def update_task_chunk_prefill(self, task):
+        """
+        Update single task's chunk_prefill info
+        """
+        idx = task.idx
+        start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
+        if task.chunk_idx == len(task.prefill_chunk_info):
+            self.model_inputs['seq_lens_encoder'][idx:idx + 1] = 0
+            self.model_inputs["step_idx"][idx:idx + 1] = 1
+            self.model_inputs["seq_lens_decoder"][
+                idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+        else:
+            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
+            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size + 1])
+            # Last prefill
+            else:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size - 1] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size])
+            self.model_inputs["seq_lens_this_time"][
+                idx:idx + 1] = token_chunk_size
+            self.model_inputs['seq_lens_encoder'][
+                idx:idx + 1] = token_chunk_size
+            self.model_inputs["step_idx"][idx:idx + 1] = 0
+            self.model_inputs["seq_lens_decoder"][
+                idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+
     def _update_status(self):
         """
         Update main-model's forward info in next step.
@@ -624,6 +664,11 @@ class MTPProposer(Proposer):
         )

     def _run_impl(self, full_hidden_states):
         """"""
         target_hidden_states = self._prepare_inputs(full_hidden_states)
         self._propose(target_hidden_states=target_hidden_states)
         self._update_status()
+
+    def is_chunk_prefill_enabled(self):
+        """"""
+        return True
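Putting the additions to this file together, update_task_chunk_prefill advances one request through its remaining chunks: mid chunks copy token_chunk_size prompt ids shifted by one position (the draft model's input is offset by one token relative to the target), the last prefill chunk copies one id fewer, and once every chunk is consumed the slot flips to decode mode (seq_lens_encoder = 0, step_idx = 1). A dry run of just that bookkeeping, with made-up chunk sizes:

    prefill_chunk_info = [128, 128, 44]   # made-up chunk sizes
    base_decoder_len = 0                  # task.get("seq_lens_decoder", 0)

    for chunk_idx in range(1, len(prefill_chunk_info) + 1):
        start_idx = sum(prefill_chunk_info[:chunk_idx])
        if chunk_idx == len(prefill_chunk_info):
            # Every chunk consumed: switch the slot to decode mode.
            print(f"chunk {chunk_idx}: seq_lens_encoder=0, step_idx=1, "
                  f"seq_lens_decoder={start_idx + base_decoder_len}")
        else:
            token_chunk_size = prefill_chunk_info[chunk_idx]
            last_prefill = chunk_idx == len(prefill_chunk_info) - 1
            # The last prefill chunk copies one id fewer (shifted prompt).
            copied = token_chunk_size - 1 if last_prefill else token_chunk_size
            print(f"chunk {chunk_idx}: copies {copied} ids, "
                  f"seq_lens_encoder={token_chunk_size}, "
                  f"seq_lens_decoder={start_idx + base_decoder_len}")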

View File

@@ -898,6 +898,9 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["step_idx"][idx:idx + 1] = 0 self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["seq_lens_decoder"][ self.share_inputs["seq_lens_decoder"][
idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0) idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled(
):
self.proposer.update_task_chunk_prefill(task)
task.chunk_idx += 1 task.chunk_idx += 1
def _dummy_sampler_run(self) -> paddle.Tensor: def _dummy_sampler_run(self) -> paddle.Tensor:
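For completeness, the call order on the runner side: the model runner first refreshes its own share_inputs for the task's next chunk, mirrors the update into the proposer only when speculative decoding is active and the proposer opts in, and only then advances chunk_idx. A condensed, illustrative sketch of that flow (the wrapper name is hypothetical; the two proposer methods are the ones added in this commit):

    def advance_chunked_prefill(runner, task):
        # Hypothetical wrapper around the logic in GPUModelRunner above;
        # runner has already refreshed its own share_inputs for this chunk.
        if runner.speculative_decoding and runner.proposer.is_chunk_prefill_enabled():
            # Keep the MTP draft model's inputs in step with the main model.
            runner.proposer.update_task_chunk_prefill(task)
        task.chunk_idx += 1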