mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-09 02:20:17 +08:00
support chunk_prefill in MTP (#2705)
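The change lets the MTP (multi-token prediction) proposer run under chunked prefill instead of raising NotImplementedError: the Proposer base class gains an is_chunk_prefill_enabled() hook, MTPProposer sizes seq_lens_encoder / seq_lens_this_time from request.prefill_chunk_info and keeps per-chunk state in a new update_task_chunk_prefill(), and GPUModelRunner forwards each scheduled chunk to the proposer. The sketch below is a minimal, self-contained illustration of that bookkeeping, not FastDeploy code; the helper name and the example values are hypothetical. It shows how chunk_idx, prefill_chunk_info, and the one-token shift of the draft model's inputs interact, matching the slicing done in update_task_chunk_prefill.

# Hypothetical illustration of the chunked-prefill slicing used by the patch.
from typing import List, Optional


def next_chunk_tokens(prompt_token_ids: List[int],
                      prefill_chunk_info: List[int],
                      chunk_idx: int) -> Optional[List[int]]:
    """Return the draft-model input tokens for the current chunk, or None
    once every prefill chunk has been consumed (decoding then starts)."""
    if chunk_idx >= len(prefill_chunk_info):
        return None

    # Tokens already consumed by earlier chunks (the patch calls this start_idx).
    start_idx = sum(prefill_chunk_info[:chunk_idx])
    token_chunk_size = prefill_chunk_info[chunk_idx]

    if chunk_idx < len(prefill_chunk_info) - 1:
        # Intermediate chunk: the draft model reads token_chunk_size tokens,
        # shifted right by one so it predicts token t+1 from token t.
        return prompt_token_ids[start_idx + 1:start_idx + token_chunk_size + 1]
    # Last chunk: after the shift only token_chunk_size - 1 tokens remain.
    return prompt_token_ids[start_idx + 1:start_idx + token_chunk_size]


if __name__ == "__main__":
    prompt = list(range(10))   # ten prompt tokens
    chunks = [4, 4, 2]         # hypothetical prefill_chunk_info
    for i in range(len(chunks) + 1):
        print(i, next_chunk_tokens(prompt, chunks, i))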
@@ -61,3 +61,13 @@ class Proposer(ABC):
         Implemention for different method
         """
         raise NotImplementedError
+
+    def is_chunk_prefill_enabled(self) -> bool:
+        """
+        Check whether chunk-based prefill is enabled.
+        Default is False.
+
+        Returns:
+            bool: True if chunk prefill is enabled; False otherwise.
+        """
+        return False
@@ -405,17 +405,21 @@ class MTPProposer(Proposer):
                         1:length]
             self.model_inputs["pre_ids"][idx:idx + 1] = -1
             self.model_inputs["step_idx"][idx:idx + 1] = 0
-            # TODO(liuzichang) finish chunked_prefill
             if self.parallel_config.enable_chunked_prefill:
-                raise NotImplementedError(
-                    "MTP don't support chunked_prefill now")
+                token_chunk_size = request.prefill_chunk_info[0]
+                self.model_inputs["seq_lens_encoder"][idx:idx +
+                                                      1] = token_chunk_size
+                self.model_inputs["seq_lens_this_time"][
+                    idx:idx + 1] = token_chunk_size
             else:
                 self.model_inputs["seq_lens_encoder"][idx:idx + 1] = length
-                self.model_inputs["seq_lens_decoder"][idx:idx + 1] = (
-                    request.get("seq_lens_decoder", 0))
                 self.model_inputs["seq_lens_this_time"][idx:idx +
                                                         1] = length
 
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = (request.get(
+                                                      "seq_lens_decoder",
+                                                      0))
             self.model_inputs["stop_flags"][idx:idx + 1] = False
             self.model_inputs["batch_drop"][idx:idx + 1] = False
 
@@ -578,7 +582,6 @@ class MTPProposer(Proposer):
             self.model_inputs["output_padding_offset"],
             self.parallel_config.max_model_len,
         )
-        paddle.device.synchronize()
 
         # 4. Compute logits, Sample
         logits = self.model.compute_logits(hiddden_states)
@@ -595,6 +598,43 @@ class MTPProposer(Proposer):
 
         self._post_process(sampled_token_ids)
 
+    def update_task_chunk_prefill(self, task):
+        """
+        Update single task's chunk_prefill info
+        """
+        idx = task.idx
+        start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
+
+        if task.chunk_idx == len(task.prefill_chunk_info):
+            self.model_inputs['seq_lens_encoder'][idx:idx + 1] = 0
+            self.model_inputs["step_idx"][idx:idx + 1] = 1
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = start_idx + task.get(
+                                                      "seq_lens_decoder", 0)
+        else:
+            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
+
+            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size + 1])
+            # Last prefill
+            else:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size - 1] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size])
+
+            self.model_inputs["seq_lens_this_time"][idx:idx +
+                                                    1] = token_chunk_size
+            self.model_inputs['seq_lens_encoder'][idx:idx +
+                                                  1] = token_chunk_size
+            self.model_inputs["step_idx"][idx:idx + 1] = 0
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = start_idx + task.get(
+                                                      "seq_lens_decoder", 0)
+
     def _update_status(self):
         """
         Update main-model's forward info in next step.
@@ -624,6 +664,11 @@ class MTPProposer(Proposer):
         )
 
     def _run_impl(self, full_hidden_states):
+        """"""
         target_hidden_states = self._prepare_inputs(full_hidden_states)
         self._propose(target_hidden_states=target_hidden_states)
         self._update_status()
+
+    def is_chunk_prefill_enabled(self):
+        """"""
+        return True
@@ -898,6 +898,9 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_idx"][idx:idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][
                     idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+                if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled(
+                ):
+                    self.proposer.update_task_chunk_prefill(task)
                 task.chunk_idx += 1
 
     def _dummy_sampler_run(self) -> paddle.Tensor:
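The runner-side hook reduces to a small capability check: the base proposer answers False, the MTP proposer overrides it to True, and only then does GPUModelRunner forward the chunk update before advancing chunk_idx. Below is a stripped-down sketch of that dispatch pattern with hypothetical stand-in classes; only the two method names come from the patch.

# Hypothetical, simplified sketch of the capability-hook dispatch.
class BaseProposer:
    def is_chunk_prefill_enabled(self) -> bool:
        return False          # default: proposer ignores chunked prefill


class MTPStyleProposer(BaseProposer):
    def is_chunk_prefill_enabled(self) -> bool:
        return True           # draft model keeps its own per-chunk state

    def update_task_chunk_prefill(self, task) -> None:
        print(f"update draft inputs for chunk {task.chunk_idx}")


class Runner:
    def __init__(self, proposer: BaseProposer, speculative_decoding: bool):
        self.proposer = proposer
        self.speculative_decoding = speculative_decoding

    def on_chunk_scheduled(self, task) -> None:
        # Mirrors the GPUModelRunner change: update the proposer only when
        # speculative decoding is on and the proposer understands chunks.
        if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled():
            self.proposer.update_task_chunk_prefill(task)
        task.chunk_idx += 1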