diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
index 78eb6c1d4..4e547d297 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
@@ -28,15 +28,15 @@
 #define MAX_DRAFT_TOKEN_NUM 6
 
 struct batch_msgdata {
-    int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
-    float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
-    int ranks[MAX_DRAFT_TOKEN_NUM];
+  int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
+  float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
+  int ranks[MAX_DRAFT_TOKEN_NUM];
 };
 
 struct msgdata {
-    long mtype;
-    int meta[3 + MAX_BSZ];  // stop_flag, message_flag, bsz, batch_token_nums
-    batch_msgdata mtext[MAX_BSZ];
+  long mtype;
+  int meta[3 + MAX_BSZ];  // stop_flag, message_flag, bsz, batch_token_nums
+  batch_msgdata mtext[MAX_BSZ];
 };
 
 void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
@@ -46,146 +46,150 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
                               const paddle::Tensor& token_num_per_batch,
                               const paddle::Tensor& cu_batch_token_offset,
                               const paddle::Tensor& not_need_stop,
+                              const paddle::Tensor& seq_lens_decoder,
+                              const paddle::Tensor& prompt_lens,
                               int message_flag,  // Target: 3, Draft: 4
-                              int64_t rank_id) {
-    if (rank_id > 0) {
-        return;
-    }
-    auto sampled_token_ids_cpu =
-        sampled_token_ids.copy_to(paddle::CPUPlace(), false);
-    auto logprob_token_ids_cpu =
-        logprob_token_ids.copy_to(paddle::CPUPlace(), false);
-    auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
-    auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false);
-    auto token_num_per_batch_cpu =
-        token_num_per_batch.copy_to(paddle::CPUPlace(), false);
-    auto cu_batch_token_offset_cpu =
-        cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
-    int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
-    int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
-    float* logprob_scores_data = logprob_scores_cpu.data<float>();
-    int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
-    int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
-    int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
+                              int64_t rank_id,
+                              bool save_each_rank) {
+  if (!save_each_rank && rank_id > 0) {
+    return;
+  }
+  auto sampled_token_ids_cpu =
+      sampled_token_ids.copy_to(paddle::CPUPlace(), false);
+  auto logprob_token_ids_cpu =
+      logprob_token_ids.copy_to(paddle::CPUPlace(), false);
+  auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
+  auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false);
+  auto token_num_per_batch_cpu =
+      token_num_per_batch.copy_to(paddle::CPUPlace(), false);
+  auto cu_batch_token_offset_cpu =
+      cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
+  auto seq_lens_decoder_cpu =
+      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
+  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
+  int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
+  int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
+  float* logprob_scores_data = logprob_scores_cpu.data<float>();
+  int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
+  int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
+  int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
+  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
+  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
-    static struct msgdata msg_sed;
-    int msg_queue_id = 1;
-    if (const char* inference_msg_queue_id_env_p =
-            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
-        std::string inference_msg_queue_id_env_str(
-            inference_msg_queue_id_env_p);
-        int inference_msg_queue_id_from_env =
-            std::stoi(inference_msg_queue_id_env_str);
-        msg_queue_id = inference_msg_queue_id_from_env;
+  static struct msgdata msg_sed;
+  int msg_queue_id = 1;
+  if (const char* inference_msg_queue_id_env_p =
+          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
+    int inference_msg_queue_id_from_env =
+        std::stoi(inference_msg_queue_id_env_str);
+    msg_queue_id = inference_msg_queue_id_from_env;
 #ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
-        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
-                  << inference_msg_queue_id_from_env << std::endl;
+    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
+              << inference_msg_queue_id_from_env << std::endl;
 #endif
+  } else {
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
+              << std::endl;
+#endif
+  }
+  int inference_msg_id_from_env = 1;
+  if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
+    std::string inference_msg_id_env_str(inference_msg_id_env_p);
+    inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
+    if (inference_msg_id_from_env == 2) {
+      // 2 and -2 is perserve for no-output indication.
+      throw std::runtime_error(
+          " INFERENCE_MSG_ID cannot be 2, please use other number.");
+    }
+    if (inference_msg_id_from_env < 0) {
+      throw std::runtime_error(
+          " INFERENCE_MSG_ID cannot be negative, please use other "
+          "number.");
+    }
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
+              << std::endl;
+#endif
+  } else {
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
+              << std::endl;
+#endif
+  }
+  static key_t key = ftok("/dev/shm", msg_queue_id);
+  static int msgid = msgget(key, IPC_CREAT | 0666);
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+  std::cout << "save_output_key: " << key << std::endl;
+  std::cout << "save msgid: " << msgid << std::endl;
+#endif
+  msg_sed.mtype = 1;
+  msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
+                                                  : -inference_msg_id_from_env;
+  msg_sed.meta[1] = message_flag;
+  int bsz = token_num_per_batch.shape()[0];
+  msg_sed.meta[2] = bsz;
+  int max_num_logprobs = logprob_token_ids.shape()[1];
+  for (int i = 0; i < bsz; i++) {
+    int cur_token_num;
+    if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
+      cur_token_num = 0;
     } else {
-#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
-        std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
-                  << std::endl;
-#endif
+      cur_token_num = token_num_per_batch_data[i];
     }
-    int inference_msg_id_from_env = 1;
-    if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
-        std::string inference_msg_id_env_str(inference_msg_id_env_p);
-        inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
-        if (inference_msg_id_from_env == 2) {
-            // 2 and -2 is perserve for no-output indication.
-            throw std::runtime_error(
-                " INFERENCE_MSG_ID cannot be 2, please use other number.");
+    msg_sed.meta[3 + i] = cur_token_num;
+    auto* cur_batch_msg_sed = &msg_sed.mtext[i];
+    int token_offset = cu_batch_token_offset_data[i];
+    for (int j = 0; j < cur_token_num; j++) {
+      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
+      auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
+      for (int k = 0; k < K + 1; k++) {
+        if (k == 0) {
+          cur_tokens[k] = (int)sampled_token_ids_data[token_offset + j];
+          cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
+        } else if (k < max_num_logprobs) {
+          cur_tokens[k] =
+              (int)logprob_token_ids_data[(token_offset + j) * (K + 1) + k];
+          cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
+        } else {
+          cur_tokens[k] = -1;
+          cur_scores[k] = 0.0;
         }
-        if (inference_msg_id_from_env < 0) {
-            throw std::runtime_error(
-                " INFERENCE_MSG_ID cannot be negative, please use other "
-                "number.");
-        }
-#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
-        std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
-                  << std::endl;
-#endif
-    } else {
-#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
-        std::cout
-            << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
-            << std::endl;
-#endif
+      }
+      cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
     }
-    static key_t key = ftok("/dev/shm", msg_queue_id);
-    static int msgid = msgget(key, IPC_CREAT | 0666);
+  }
 #ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
-    std::cout << "save_output_key: " << key << std::endl;
-    std::cout << "save msgid: " << msgid << std::endl;
+  std::cout << "msg data: " << std::endl;
+  std::cout << "stop_flag: " << msg_sed.meta[0]
+            << ", message_flag: " << msg_sed.meta[1]
+            << ", bsz: " << msg_sed.meta[2] << std::endl;
+  for (int i = 0; i < bsz; i++) {
+    int cur_token_num = msg_sed.meta[3 + i];
+    auto* cur_batch_msg_sed = &msg_sed.mtext[i];
+    std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl;
+    for (int j = 0; j < cur_token_num; j++) {
+      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
+      auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
+      std::cout << "tokens: ";
+      for (int k = 0; k < K + 1; k++) {
+        std::cout << cur_tokens[k] << " ";
+      }
+      std::cout << std::endl;
+      std::cout << "scores: ";
+      for (int k = 0; k < K + 1; k++) {
+        std::cout << cur_scores[k] << " ";
+      }
+      std::cout << std::endl;
+      std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl;
+    }
+  }
+  std::cout << std::endl;
 #endif
-    msg_sed.mtype = 1;
-    msg_sed.meta[0] = not_need_stop.data<bool>()[0]
-                          ? inference_msg_id_from_env
-                          : -inference_msg_id_from_env;
-    msg_sed.meta[1] = message_flag;
-    int bsz = token_num_per_batch.shape()[0];
-    msg_sed.meta[2] = bsz;
-    int max_num_logprobs = logprob_token_ids.shape()[1];
-    for (int i = 0; i < bsz; i++) {
-        int cur_token_num = token_num_per_batch_data[i];
-        msg_sed.meta[3 + i] = cur_token_num;
-        auto* cur_batch_msg_sed = &msg_sed.mtext[i];
-        int token_offset = cu_batch_token_offset_data[i];
-        for (int j = 0; j < cur_token_num; j++) {
-            auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
-            auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
-            for (int k = 0; k < K + 1; k++) {
-                if (k == 0) {
-                    cur_tokens[k] =
-                        (int)sampled_token_ids_data[token_offset + j];
-                    cur_scores[k] =
-                        logprob_scores_data[(token_offset + j) * (K + 1) + k];
-                } else if (k < max_num_logprobs) {
-                    cur_tokens[k] = (int)
-                        logprob_token_ids_data[(token_offset + j) * (K + 1) +
-                                               k];
-                    cur_scores[k] =
-                        logprob_scores_data[(token_offset + j) * (K + 1) + k];
-                } else {
-                    cur_tokens[k] = -1;
-                    cur_scores[k] = 0.0;
-                }
-            }
-            cur_batch_msg_sed->ranks[j] =
-                (int)logprob_ranks_data[token_offset + j];
-        }
-    }
-#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
-    std::cout << "msg data: " << std::endl;
-    std::cout << "stop_flag: " << msg_sed.meta[0]
-              << ", message_flag: " << msg_sed.meta[1]
-              << ", bsz: " << msg_sed.meta[2] << std::endl;
-    for (int i = 0; i < bsz; i++) {
-        int cur_token_num = msg_sed.meta[3 + i];
-        auto* cur_batch_msg_sed = &msg_sed.mtext[i];
-        std::cout << "batch " << i << " token_num: " << cur_token_num
-                  << std::endl;
-        for (int j = 0; j < cur_token_num; j++) {
-            auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
-            auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
-            std::cout << "tokens: ";
-            for (int k = 0; k < K + 1; k++) {
-                std::cout << cur_tokens[k] << " ";
-            }
-            std::cout << std::endl;
-            std::cout << "scores: ";
-            for (int k = 0; k < K + 1; k++) {
-                std::cout << cur_scores[k] << " ";
-            }
-            std::cout << std::endl;
-            std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl;
-        }
-    }
-    std::cout << std::endl;
-#endif
-    if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) {
-        printf("full msg buffer\n");
-    }
+  if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) {
+    printf("full msg buffer\n");
+  }
 }
 
 PD_BUILD_STATIC_OP(speculate_save_output_topk)
@@ -197,6 +201,8 @@ PD_BUILD_STATIC_OP(speculate_save_output_topk)
         "token_num_per_batch",
         "cu_batch_token_offset",
         "not_need_stop",
+        "seq_lens_decoder",
+        "prompt_lens",
     })
-    .Attrs({"message_flag: int", "rank_id: int64_t"})
+    .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"})
     .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK));
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index e25e9360f..411eec7d1 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -514,8 +514,11 @@ def post_process_specualate(
             sampler_output.token_num_per_batch,
             sampler_output.cu_batch_token_offset,
             model_output.not_need_stop,
+            model_output.seq_lens_decoder,
+            model_output.prompt_lens,
             3,  # mtype
             model_output.mp_rank,
+            save_each_rank,
         )
 
     # Update pre_ids through accept tokens
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 4aa3ab307..529e1f4a7 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -830,6 +830,8 @@ class TokenProcessor:
         num_accepted_tokens = sum([x - 1 for x in real_accept_num])
         self.num_accepted_tokens += num_accepted_tokens
         num_emitted_tokens = sum(real_accept_num)
+        if num_emitted_tokens == 0:
+            return
         self.num_emitted_tokens += num_emitted_tokens
         main_process_metrics.spec_decode_num_accepted_tokens_total.inc(num_accepted_tokens)
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 611c3ab5f..437e89bd7 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -338,6 +338,7 @@ class MTPProposer(Proposer):
         self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
         self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
+        self.model_inputs["prompt_lens"] = paddle.clone(self.target_model_inputs["prompt_lens"])
         self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
         self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
         self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
@@ -766,7 +767,7 @@ class MTPProposer(Proposer):
             self.model_inputs["step_idx"],
         )
 
-    def _propose(self, step_use_cudagraph: bool = False):
+    def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
         """
         Main process for MTP inference.
         Args:
@@ -891,7 +892,12 @@ class MTPProposer(Proposer):
                     self.model_inputs,
                 )
 
-                if substep == 0 and sampler_output.logprobs_tensors is not None:
+                if (
+                    not is_dummy_run
+                    and self.parallel_config.tensor_parallel_rank == 0
+                    and substep == 0
+                    and sampler_output.logprobs_tensors is not None
+                ):
                     real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                     speculate_save_output_topk(
                         sampler_output.sampled_token_ids,
@@ -901,8 +907,11 @@ class MTPProposer(Proposer):
                         self.model_inputs["batch_token_num"][:real_bsz],
                         self.model_inputs["cu_batch_token_offset"][:real_bsz],
                         self.model_inputs["not_need_stop"],
+                        self.model_inputs["seq_lens_decoder"],
+                        self.model_inputs["prompt_lens"],
                         4,  # mtype
                         self.local_rank,
+                        self.parallel_config.use_ep,
                     )
 
                 if self.parallel_config.tensor_parallel_size > 1:
@@ -1009,10 +1018,12 @@ class MTPProposer(Proposer):
         self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
         self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
 
-    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+    def _run_impl(
+        self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
+    ):
         """Execute Draft Model"""
         self._prepare_inputs(full_hidden_states)
-        self._propose(step_use_cudagraph=step_use_cudagraph)
+        self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 81f44ff81..f824b48a4 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1772,7 +1772,9 @@ class GPUModelRunner(ModelRunnerBase):
         if self.speculative_decoding:
             if self.speculative_method == "mtp":
                 self.proposer.run(
-                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                    full_hidden_states=model_output,
+                    step_use_cudagraph=self.forward_meta.step_use_cudagraph,
+                    is_dummy_run=True,
                 )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)
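
Note (reviewer addition, not part of the patch): the new `seq_lens_decoder` / `prompt_lens` inputs let `SpeculateSaveOutMmsgTopK` skip requests that are still consuming their prompt under chunked prefill; when `seq_lens_decoder[i] < prompt_lens[i]`, batch `i` reports zero tokens in `meta[3 + i]`. A minimal Python sketch of that gate, with illustrative names only (none of these helpers exist in FastDeploy):

```python
from typing import List


def per_batch_emitted_token_counts(
    token_num_per_batch: List[int],
    seq_lens_decoder: List[int],
    prompt_lens: List[int],
) -> List[int]:
    """Mirror the gate added in SpeculateSaveOutMmsgTopK: a request that has
    not finished its prompt yet (chunked prefill) reports 0 tokens."""
    return [
        0 if seq_len < prompt_len else token_num
        for token_num, seq_len, prompt_len in zip(token_num_per_batch, seq_lens_decoder, prompt_lens)
    ]


# Batch 1 is still prefilling (decoder len 48 < prompt len 96), so it emits nothing this step.
assert per_batch_emitted_token_counts([3, 2, 1], [100, 48, 10], [80, 96, 10]) == [3, 0, 1]
```

The `num_emitted_tokens == 0` early return added in `TokenProcessor` fits the same picture: a step in which no request emits anything presumably should leave the speculative-decoding counters and metrics untouched.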
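
Also a reviewer addition, not part of the patch: the on-queue record format is fixed by the `msgdata` / `batch_msgdata` structs at the top of the diff. FastDeploy's real consumer lives elsewhere and is not shown here; the sketch below only illustrates how the payload that follows the C `long mtype` field could be decoded on a little-endian host. `K` and `MAX_BSZ` are placeholders that must match the constants defined earlier in the `.cc` file.

```python
import struct

# Placeholders: must match speculate_save_output_with_topk.cc
# (MAX_DRAFT_TOKEN_NUM is 6 in this diff; K and MAX_BSZ are defined elsewhere in that file).
MAX_DRAFT_TOKEN_NUM = 6
K = 20
MAX_BSZ = 512

_META_FMT = f"<{3 + MAX_BSZ}i"
_REC_INTS = MAX_DRAFT_TOKEN_NUM * (K + 1)
_REC_FMT = f"<{_REC_INTS}i{_REC_INTS}f{MAX_DRAFT_TOKEN_NUM}i"


def decode_payload(payload: bytes):
    """Decode the bytes after the `long mtype` field of one msgdata record:
    meta[3 + MAX_BSZ] ints, then MAX_BSZ batch_msgdata records, each holding
    tokens (int), scores (float), and ranks (int) for up to
    MAX_DRAFT_TOKEN_NUM accepted tokens with K + 1 logprob slots apiece."""
    meta = struct.unpack_from(_META_FMT, payload, 0)
    stop_flag, message_flag, bsz = meta[0], meta[1], meta[2]
    base = struct.calcsize(_META_FMT)
    rec_size = struct.calcsize(_REC_FMT)
    batches = []
    for i in range(bsz):
        fields = struct.unpack_from(_REC_FMT, payload, base + i * rec_size)
        batches.append(
            {
                "token_num": meta[3 + i],
                "tokens": fields[:_REC_INTS],
                "scores": fields[_REC_INTS:2 * _REC_INTS],
                "ranks": fields[2 * _REC_INTS:],
            }
        )
    return stop_flag, message_flag, batches
```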