diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py
index aa9950ef5..86eaabc1f 100644
--- a/fastdeploy/spec_decode/base.py
+++ b/fastdeploy/spec_decode/base.py
@@ -61,3 +61,13 @@ class Proposer(ABC):
         Implemention for different method
         """
         raise NotImplementedError
+
+    def is_chunk_prefill_enabled(self) -> bool:
+        """
+        Check whether chunk-based prefill is enabled.
+        Default is False.
+
+        Returns:
+            bool: True if chunk prefill is enabled; False otherwise.
+        """
+        return False
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 264656cbf..ba50d1f89 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -405,17 +405,21 @@ class MTPProposer(Proposer):
                 1:length]
             self.model_inputs["pre_ids"][idx:idx + 1] = -1
             self.model_inputs["step_idx"][idx:idx + 1] = 0
-            # TODO(liuzichang) finish chunked_prefill
             if self.parallel_config.enable_chunked_prefill:
-                raise NotImplementedError(
-                    "MTP don't support chunked_prefill now")
+                token_chunk_size = request.prefill_chunk_info[0]
+                self.model_inputs["seq_lens_encoder"][idx:idx +
+                                                      1] = token_chunk_size
+                self.model_inputs["seq_lens_this_time"][
+                    idx:idx + 1] = token_chunk_size
             else:
                 self.model_inputs["seq_lens_encoder"][idx:idx + 1] = length
-                self.model_inputs["seq_lens_decoder"][idx:idx + 1] = (
-                    request.get("seq_lens_decoder", 0))
                 self.model_inputs["seq_lens_this_time"][idx:idx + 1] = length
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = (request.get(
+                                                      "seq_lens_decoder",
+                                                      0))
             self.model_inputs["stop_flags"][idx:idx + 1] = False
             self.model_inputs["batch_drop"][idx:idx + 1] = False
@@ -578,7 +582,6 @@ class MTPProposer(Proposer):
             self.model_inputs["output_padding_offset"],
             self.parallel_config.max_model_len,
         )
-        paddle.device.synchronize()
 
         # 4. Compute logits, Sample
         logits = self.model.compute_logits(hiddden_states)
@@ -595,6 +598,43 @@ class MTPProposer(Proposer):
 
         self._post_process(sampled_token_ids)
 
+    def update_task_chunk_prefill(self, task):
+        """
+        Update single task's chunk_prefill info
+        """
+        idx = task.idx
+        start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
+
+        if task.chunk_idx == len(task.prefill_chunk_info):
+            self.model_inputs['seq_lens_encoder'][idx:idx + 1] = 0
+            self.model_inputs["step_idx"][idx:idx + 1] = 1
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = start_idx + task.get(
+                                                      "seq_lens_decoder", 0)
+        else:
+            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
+
+            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size + 1])
+            # Last prefill
+            else:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size - 1] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size])
+
+            self.model_inputs["seq_lens_this_time"][idx:idx +
+                                                    1] = token_chunk_size
+            self.model_inputs['seq_lens_encoder'][idx:idx +
+                                                  1] = token_chunk_size
+            self.model_inputs["step_idx"][idx:idx + 1] = 0
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = start_idx + task.get(
+                                                      "seq_lens_decoder", 0)
+
     def _update_status(self):
         """
         Update main-model's forward info in next step.
@@ -624,6 +664,11 @@ class MTPProposer(Proposer):
         )
 
     def _run_impl(self, full_hidden_states):
+        """Prepare draft-model inputs, propose draft tokens, and update status."""
         target_hidden_states = self._prepare_inputs(full_hidden_states)
         self._propose(target_hidden_states=target_hidden_states)
         self._update_status()
+
+    def is_chunk_prefill_enabled(self):
+        """MTP supports chunked prefill."""
+        return True
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index c13f232d3..8d6ca79a1 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -898,6 +898,9 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_idx"][idx:idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][
                     idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+            if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled(
+            ):
+                self.proposer.update_task_chunk_prefill(task)
             task.chunk_idx += 1
 
     def _dummy_sampler_run(self) -> paddle.Tensor:
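The slicing in `update_task_chunk_prefill` is easiest to follow in isolation: the draft model consumes the prompt one chunk at a time, shifted right by one token relative to the target model, and the slot flips to decode mode once every chunk has been fed. Below is a minimal standalone sketch of that bookkeeping; `DraftChunkState` and `next_draft_chunk` are hypothetical illustration helpers rather than FastDeploy APIs, and only `prompt_token_ids`, `prefill_chunk_info`, and `chunk_idx` mirror fields used in the diff.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class DraftChunkState:
    """Per-task fields needed for the chunk bookkeeping (illustrative only)."""
    prompt_token_ids: List[int]
    prefill_chunk_info: List[int]  # token count of each prefill chunk
    chunk_idx: int = 0             # next chunk to prefill
    seq_lens_decoder: int = 0      # tokens already cached before this request


def next_draft_chunk(task: DraftChunkState):
    """Return (input_ids, seq_lens_this_time, seq_lens_decoder, step_idx) for the
    draft model's next chunk, following the branching in update_task_chunk_prefill."""
    start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])

    if task.chunk_idx == len(task.prefill_chunk_info):
        # Every chunk has been fed: switch the slot to decode mode (step_idx = 1).
        return [], 0, start_idx + task.seq_lens_decoder, 1

    token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
    if task.chunk_idx < len(task.prefill_chunk_info) - 1:
        # Draft-model input is the prompt shifted right by one token.
        input_ids = task.prompt_token_ids[start_idx + 1:start_idx + token_chunk_size + 1]
    else:
        # Last chunk: one fewer token remains because of the shift.
        input_ids = task.prompt_token_ids[start_idx + 1:start_idx + token_chunk_size]
    return input_ids, token_chunk_size, start_idx + task.seq_lens_decoder, 0


if __name__ == "__main__":
    task = DraftChunkState(prompt_token_ids=list(range(10)),
                           prefill_chunk_info=[4, 4, 2])
    for _ in range(len(task.prefill_chunk_info) + 1):
        print(next_draft_chunk(task))
        task.chunk_idx += 1
```

Running the sketch on a 10-token prompt with chunks [4, 4, 2] walks through two full chunks, a shortened final chunk, and then the switch to decode mode, matching the three branches in the method above.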
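The runner-side integration is a small opt-in capability check: `Proposer.is_chunk_prefill_enabled()` defaults to `False`, so only proposers that override it (here, `MTPProposer`) receive the per-chunk callback. A hedged sketch of that dispatch, with stub classes and a free function standing in for the real `GPUModelRunner` call site:

```python
from abc import ABC


class Proposer(ABC):
    def is_chunk_prefill_enabled(self) -> bool:
        # Base default mirrors base.py: a proposer takes no part in chunked
        # prefill unless it overrides this hook.
        return False


class MTPProposer(Proposer):
    def is_chunk_prefill_enabled(self) -> bool:
        return True

    def update_task_chunk_prefill(self, task) -> None:
        # The real implementation advances the draft model's prefill chunk (see mtp.py).
        print(f"advancing draft-model prefill for task {task!r}")


def update_chunked_prefill(proposer: Proposer, speculative_decoding: bool, task) -> None:
    # Mirrors the call site added in gpu_model_runner.py: the extra work happens
    # only when speculative decoding is on and the proposer opted in.
    if speculative_decoding and proposer.is_chunk_prefill_enabled():
        proposer.update_task_chunk_prefill(task)
    # ...the runner then advances task.chunk_idx...


if __name__ == "__main__":
    update_chunked_prefill(MTPProposer(), speculative_decoding=True, task="req-0")
```

Because the default returns `False`, proposers that do not support chunked prefill need no changes, and the runner stays agnostic to which speculative method is active.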