mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 17:41:52 +08:00
support chunk_prefill in MTP (#2705)
This commit is contained in:
@@ -405,17 +405,21 @@ class MTPProposer(Proposer):
|
||||
1:length]
|
||||
self.model_inputs["pre_ids"][idx:idx + 1] = -1
|
||||
self.model_inputs["step_idx"][idx:idx + 1] = 0
|
||||
# TODO(liuzichang) finish chunked_prefill
|
||||
if self.parallel_config.enable_chunked_prefill:
|
||||
raise NotImplementedError(
|
||||
"MTP don't support chunked_prefill now")
|
||||
token_chunk_size = request.prefill_chunk_info[0]
|
||||
self.model_inputs["seq_lens_encoder"][idx:idx +
|
||||
1] = token_chunk_size
|
||||
self.model_inputs["seq_lens_this_time"][
|
||||
idx:idx + 1] = token_chunk_size
|
||||
else:
|
||||
self.model_inputs["seq_lens_encoder"][idx:idx + 1] = length
|
||||
self.model_inputs["seq_lens_decoder"][idx:idx + 1] = (
|
||||
request.get("seq_lens_decoder", 0))
|
||||
self.model_inputs["seq_lens_this_time"][idx:idx +
|
||||
1] = length
|
||||
|
||||
self.model_inputs["seq_lens_decoder"][idx:idx +
|
||||
1] = (request.get(
|
||||
"seq_lens_decoder",
|
||||
0))
|
||||
self.model_inputs["stop_flags"][idx:idx + 1] = False
|
||||
self.model_inputs["batch_drop"][idx:idx + 1] = False
|
||||
|
||||
@@ -578,7 +582,6 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["output_padding_offset"],
|
||||
self.parallel_config.max_model_len,
|
||||
)
|
||||
paddle.device.synchronize()
|
||||
|
||||
# 4. Compute logits, Sample
|
||||
logits = self.model.compute_logits(hiddden_states)
|
||||
@@ -595,6 +598,43 @@ class MTPProposer(Proposer):
|
||||
|
||||
self._post_process(sampled_token_ids)
|
||||
|
||||
def update_task_chunk_prefill(self, task):
|
||||
"""
|
||||
Update single task's chunk_prefill info
|
||||
"""
|
||||
idx = task.idx
|
||||
start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
|
||||
|
||||
if task.chunk_idx == len(task.prefill_chunk_info):
|
||||
self.model_inputs['seq_lens_encoder'][idx:idx + 1] = 0
|
||||
self.model_inputs["step_idx"][idx:idx + 1] = 1
|
||||
self.model_inputs["seq_lens_decoder"][idx:idx +
|
||||
1] = start_idx + task.get(
|
||||
"seq_lens_decoder", 0)
|
||||
else:
|
||||
token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
|
||||
|
||||
if task.chunk_idx < len(task.prefill_chunk_info) - 1:
|
||||
self.model_inputs['input_ids'][
|
||||
idx, :token_chunk_size] = np.array(
|
||||
task.prompt_token_ids[start_idx + 1:start_idx +
|
||||
token_chunk_size + 1])
|
||||
# Last prefill
|
||||
else:
|
||||
self.model_inputs['input_ids'][
|
||||
idx, :token_chunk_size - 1] = np.array(
|
||||
task.prompt_token_ids[start_idx + 1:start_idx +
|
||||
token_chunk_size])
|
||||
|
||||
self.model_inputs["seq_lens_this_time"][idx:idx +
|
||||
1] = token_chunk_size
|
||||
self.model_inputs['seq_lens_encoder'][idx:idx +
|
||||
1] = token_chunk_size
|
||||
self.model_inputs["step_idx"][idx:idx + 1] = 0
|
||||
self.model_inputs["seq_lens_decoder"][idx:idx +
|
||||
1] = start_idx + task.get(
|
||||
"seq_lens_decoder", 0)
|
||||
|
||||
def _update_status(self):
|
||||
"""
|
||||
Update main-model's forward info in next step.
|
||||
@@ -624,6 +664,11 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
|
||||
def _run_impl(self, full_hidden_states):
|
||||
""""""
|
||||
target_hidden_states = self._prepare_inputs(full_hidden_states)
|
||||
self._propose(target_hidden_states=target_hidden_states)
|
||||
self._update_status()
|
||||
|
||||
def is_chunk_prefill_enabled(self):
|
||||
""""""
|
||||
return True
|
||||
|
Reference in New Issue
Block a user