diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
index 6440c65a8..1addce9c2 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
@@ -46,9 +46,12 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
                               const paddle::Tensor& token_num_per_batch,
                               const paddle::Tensor& cu_batch_token_offset,
                               const paddle::Tensor& not_need_stop,
+                              const paddle::Tensor& seq_lens_decoder,
+                              const paddle::Tensor& prompt_lens,
                               int message_flag,  // Target: 3, Draft: 4
-                              int64_t rank_id) {
-  if (rank_id > 0) {
+                              int64_t rank_id,
+                              bool save_each_rank) {
+  if (!save_each_rank && rank_id > 0) {
     return;
   }
   auto sampled_token_ids_cpu =
@@ -61,12 +64,17 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
       token_num_per_batch.copy_to(paddle::CPUPlace(), false);
   auto cu_batch_token_offset_cpu =
       cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
+  auto seq_lens_decoder_cpu =
+      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
+  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
   int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
   int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
   float* logprob_scores_data = logprob_scores_cpu.data<float>();
   int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
   int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
   int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
+  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
+  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
 
   static struct msgdata msg_sed;
   int msg_queue_id = 1;
@@ -127,7 +135,12 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
   msg_sed.meta[2] = bsz;
   int max_num_logprobs = logprob_token_ids.shape()[1];
   for (int i = 0; i < bsz; i++) {
-    int cur_token_num = token_num_per_batch_data[i];
+    int cur_token_num;
+    if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
+      cur_token_num = 0;
+    } else {
+      cur_token_num = token_num_per_batch_data[i];
+    }
     msg_sed.meta[3 + i] = cur_token_num;
     auto* cur_batch_msg_sed = &msg_sed.mtext[i];
     int token_offset = cu_batch_token_offset_data[i];
@@ -198,6 +211,8 @@ PD_BUILD_STATIC_OP(speculate_save_output_topk)
         "token_num_per_batch",
         "cu_batch_token_offset",
         "not_need_stop",
+        "seq_lens_decoder",
+        "prompt_lens",
     })
-    .Attrs({"message_flag: int", "rank_id: int64_t"})
+    .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"})
     .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK));
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index 54218ffb8..18ba48156 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -349,8 +349,11 @@ def post_process_specualate(
             sampler_output.token_num_per_batch,
             sampler_output.cu_batch_token_offset,
             model_output.not_need_stop,
+            model_output.seq_lens_decoder,
+            model_output.prompt_lens,
             3,  # mtype
             model_output.mp_rank,
+            save_each_rank,
         )
 
     # Update pre_ids through accept tokens
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index f7825b212..0861ccb8c 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -584,6 +584,8 @@ class TokenProcessor:
             num_accepted_tokens = sum([x - 1 for x in real_accept_num])
             self.num_accepted_tokens += num_accepted_tokens
             num_emitted_tokens = sum(real_accept_num)
+            if num_emitted_tokens == 0:
+                return
             self.num_emitted_tokens += num_emitted_tokens
 
             main_process_metrics.spec_decode_num_accepted_tokens_total.inc(num_accepted_tokens)
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 02a48d394..5fe2b39ac 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -313,6 +313,7 @@ class MTPProposer(Proposer):
         self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
         self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
+        self.model_inputs["prompt_lens"] = paddle.clone(self.target_model_inputs["prompt_lens"])
         self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
         self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
         self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
@@ -705,7 +706,7 @@ class MTPProposer(Proposer):
             self.parallel_config.use_ep,
         )
 
-    def _propose(self, step_use_cudagraph: bool = False):
+    def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
         """
         Main process for MTP inference.
         Args:
@@ -830,7 +831,12 @@ class MTPProposer(Proposer):
                 self.model_inputs,
             )
 
-            if substep == 0 and sampler_output.logprobs_tensors is not None:
+            if (
+                not is_dummy_run
+                and self.parallel_config.tensor_parallel_rank == 0
+                and substep == 0
+                and sampler_output.logprobs_tensors is not None
+            ):
                 real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                 speculate_save_output_topk(
                     sampler_output.sampled_token_ids,
@@ -840,8 +846,11 @@ class MTPProposer(Proposer):
                     self.model_inputs["batch_token_num"][:real_bsz],
                     self.model_inputs["cu_batch_token_offset"][:real_bsz],
                     self.model_inputs["not_need_stop"],
+                    self.model_inputs["seq_lens_decoder"],
+                    self.model_inputs["prompt_lens"],
                     4,  # mtype
                     self.local_rank,
+                    self.parallel_config.use_ep,
                 )
 
             if self.parallel_config.tensor_parallel_size > 1:
@@ -949,10 +958,12 @@ class MTPProposer(Proposer):
         self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
         self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
 
-    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+    def _run_impl(
+        self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
+    ):
         """Execute Draft Model"""
         self._prepare_inputs(full_hidden_states)
-        self._propose(step_use_cudagraph=step_use_cudagraph)
+        self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index a1461b91c..496e04b33 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1281,7 +1281,9 @@ class GPUModelRunner(ModelRunnerBase):
         if self.speculative_decoding:
             if self.speculative_method == "mtp":
                 self.proposer.run(
-                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                    full_hidden_states=model_output,
+                    step_use_cudagraph=self.forward_meta.step_use_cudagraph,
+                    is_dummy_run=True,
                 )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)
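Illustrative sketch (not part of the patch): `speculate_save_output_topk` now also receives `seq_lens_decoder` and `prompt_lens`, so a request whose decoded length has not yet caught up with its prompt length (e.g. a prompt still being prefilled in chunks) reports zero emitted tokens, and the new `save_each_rank` attribute lets every rank save output instead of only rank 0 (the MTP draft path passes `use_ep` for it). The standalone Python below restates that gating under those assumptions; the helper name `emit_batch_token_counts` is hypothetical and does not exist in the repository.

# Hypothetical helper, for illustration only -- not code from the patch.
# It restates the gating added in speculate_save_output_with_topk.cc.
from typing import List


def emit_batch_token_counts(
    token_num_per_batch: List[int],
    seq_lens_decoder: List[int],
    prompt_lens: List[int],
    rank_id: int,
    save_each_rank: bool = False,
) -> List[int]:
    # Only rank 0 saves output unless save_each_rank is set
    # (the diff wires this to expert parallelism for the draft model).
    if not save_each_rank and rank_id > 0:
        return []
    counts = []
    for i, token_num in enumerate(token_num_per_batch):
        # A request still processing its prompt (decoded length < prompt length)
        # must not emit tokens yet, so its per-batch count is forced to zero.
        if seq_lens_decoder[i] < prompt_lens[i]:
            counts.append(0)
        else:
            counts.append(token_num)
    return counts


# Example: batch of 3, the second request is still prefilling its prompt.
assert emit_batch_token_counts([2, 3, 1], [8, 4, 16], [8, 10, 12], rank_id=0) == [2, 0, 1]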