From 582aebd48b4ab6e3c1ba745191a50268157f7a1a Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Wed, 15 Oct 2025 13:21:32 +0800 Subject: [PATCH] [MTP]support mtp chunk_prefill_v1 (#4366) * support mtp chunk_prefill_v1 * fix mtp chunkprefill output, fix unit test * fix unit test * fix save_output --- custom_ops/gpu_ops/cpp_extensions.cc | 7 ++- .../speculate_save_output.cc | 34 +++++++++---- .../speculate_schedule_cache.cu | 50 ++++++++++++++----- .../speculate_set_value_by_flags_and_idx.cu | 46 +++++++++-------- .../speculate_decoding/speculate_update.cu | 3 -- fastdeploy/engine/args_utils.py | 6 +-- .../model_executor/pre_and_post_process.py | 6 +-- fastdeploy/worker/gpu_model_runner.py | 4 ++ fastdeploy/worker/output.py | 5 ++ tests/operators/test_speculate_update.py | 3 -- .../test_speculative_schedule_cache.py | 12 +++++ 11 files changed, 118 insertions(+), 58 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 57d6201ef..eca815656 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -709,8 +709,11 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, const paddle::Tensor& not_need_stop, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& prompt_lens, int64_t rank_id, - bool save_each_rank); + bool save_each_rank, + bool skip_prefill); void SpeculateClearAcceptNums(const paddle::Tensor& accept_num, @@ -719,7 +722,9 @@ void SpeculateClearAcceptNums(const paddle::Tensor& accept_num, void SpeculateScheduleCache(const paddle::Tensor &draft_tokens, const paddle::Tensor &block_tables, const paddle::Tensor &stop_flags, + const paddle::Tensor &prompt_lens, const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &step_seq_lens_decoder, const paddle::Tensor &step_draft_tokens, diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc index df9312281..6d34b9736 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc @@ -28,9 +28,12 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, const paddle::Tensor& not_need_stop, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& prompt_lens, int64_t rank_id, int msg_queue_id, - int save_each_rank) { + int save_each_rank, + bool skip_prefill) { // printf("enter save output"); if (!save_each_rank && rank_id > 0) { return; @@ -43,6 +46,11 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens, int64_t* accept_tokens_data = accept_tokens_cpu.data(); int* accept_num_data = accept_num_cpu.data(); + auto seq_lens_decoder_cpu = seq_lens_decoder.copy_to(paddle::CPUPlace(), true); + auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true); + int* seq_lens_decoder_data = seq_lens_decoder_cpu.data(); + int64_t* prompt_lens_data = prompt_lens_cpu.data(); + if (const char* inference_msg_queue_id_env_p = std::getenv("INFERENCE_MSG_QUEUE_ID")) { std::string inference_msg_queue_id_env_str( @@ -95,7 +103,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens, msg_sed.mtext[1] = bsz; for (int i = 2; i < MAX_BSZ + 2; i++) { - if (i - 2 >= bsz) { + if (i - 2 >= bsz || (skip_prefill && seq_lens_decoder_data[i - 2] < prompt_lens_data[i - 2])) { msg_sed.mtext[i] = 0; } else { msg_sed.mtext[i] = (int)accept_num_data[i - 2]; @@ -125,32 +133,38 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens, void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, const paddle::Tensor& not_need_stop, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& prompt_lens, int64_t rank_id, - bool save_each_rank) { + bool save_each_rank, + bool skip_prefill) { SpeculateSaveWithOutputMsg( - accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank); + accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, 1, save_each_rank, skip_prefill); } void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, const paddle::Tensor& not_need_stop, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& prompt_lens, int64_t rank_id, int msg_queue_id, - bool save_each_rank) { + bool save_each_rank, + bool skip_prefill) { SpeculateSaveWithOutputMsg( - accept_tokens, accept_num, not_need_stop, rank_id, msg_queue_id, save_each_rank); + accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, msg_queue_id, save_each_rank, skip_prefill); } PD_BUILD_STATIC_OP(speculate_save_output) - .Inputs({"accept_tokens", "accept_num", "not_need_stop"}) - .Attrs({"rank_id: int64_t", "save_each_rank: bool"}) + .Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"}) + .Attrs({"rank_id: int64_t", "save_each_rank: bool", "skip_prefill: bool"}) .Outputs({"x_out"}) .SetInplaceMap({{"accept_tokens", "x_out"}}) .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic)); PD_BUILD_STATIC_OP(speculate_save_output_dynamic) - .Inputs({"accept_tokens", "accept_num", "not_need_stop"}) - .Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"}) + .Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"}) + .Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool", "skip_prefill: bool"}) .Outputs({"x_out"}) .SetInplaceMap({{"accept_tokens", "x_out"}}) .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu index 633c5bb4d..0f44293ea 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu @@ -19,7 +19,9 @@ __global__ void speculate_schedula_cache( const int64_t *draft_tokens, int *block_tables, bool *stop_flags, + const int64_t* prompt_lens, int *seq_lens_this_time, + int *seq_lens_encoder, int *seq_lens_decoder, int *step_seq_lens_decoder, int64_t *step_draft_tokens, @@ -44,23 +46,37 @@ __global__ void speculate_schedula_cache( int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len; int *block_table_now = block_tables + bid * block_num_per_seq; int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len; - const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size; - if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) { - is_block_step[bid] = true; - step_seq_lens_this_time[bid] = seq_lens_this_time[bid]; - seq_lens_this_time[bid] = 0; + + if (seq_lens_decoder[bid] >= prompt_lens[bid]) { + // decoder + const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size; + if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) { + is_block_step[bid] = true; + step_seq_lens_this_time[bid] = seq_lens_this_time[bid]; + seq_lens_this_time[bid] = 0; + stop_flags[bid] = true; + stop_flag_now_int = 1; + step_seq_lens_decoder[bid] = seq_lens_decoder[bid]; + seq_lens_decoder[bid] = 0; + accept_num[bid] = 0; + for (int i = 0; i < accept_tokens_len; i++) { + accept_tokens_now[i] = -1; + } + for (int i = 0; i < draft_tokens_len; i++) { + step_draft_tokens_now[i] = draft_tokens_now[i]; + } + } + } else { + // prefill stop_flags[bid] = true; - stop_flag_now_int = 1; - step_seq_lens_decoder[bid] = seq_lens_decoder[bid]; + seq_lens_this_time[bid] = 0; seq_lens_decoder[bid] = 0; + seq_lens_encoder[bid] = 0; accept_num[bid] = 0; - for (int i = 0; i < accept_tokens_len; i++) { - accept_tokens_now[i] = -1; - } - for (int i = 0; i < draft_tokens_len; i++) { - step_draft_tokens_now[i] = draft_tokens_now[i]; - } + stop_flag_now_int = 1; } + + } else { stop_flag_now_int = 1; } @@ -83,7 +99,9 @@ __global__ void speculate_schedula_cache( void SpeculateScheduleCache(const paddle::Tensor &draft_tokens, const paddle::Tensor &block_tables, const paddle::Tensor &stop_flags, + const paddle::Tensor &prompt_lens, const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &step_seq_lens_decoder, const paddle::Tensor &step_draft_tokens, @@ -109,7 +127,9 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens, draft_tokens.data(), const_cast(block_tables.data()), const_cast(stop_flags.data()), + prompt_lens.data(), const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), const_cast(seq_lens_decoder.data()), const_cast(step_seq_lens_decoder.data()), const_cast(step_draft_tokens.data()), @@ -138,7 +158,9 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache) .Inputs({"draft_tokens", "block_tables", "stop_flags", + "prompt_lens", "seq_lens_this_time", + "seq_lens_encoder", "seq_lens_decoder", "step_seq_lens_decoder", "step_draft_tokens", @@ -153,6 +175,7 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache) "block_tables_out", "stop_flags_out", "seq_lens_this_time_out", + "seq_lens_encoder_out", "seq_lens_decoder_out", "step_seq_lens_decoder_out", "step_draft_tokens_out", @@ -165,6 +188,7 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache) {"block_tables", "block_tables_out"}, {"stop_flags", "stop_flags_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, {"seq_lens_decoder", "seq_lens_decoder_out"}, {"step_seq_lens_decoder", "step_seq_lens_decoder_out"}, {"step_draft_tokens", "step_draft_tokens_out"}, diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu index 4b1c7747e..d1ee733fe 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu @@ -20,30 +20,33 @@ __global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all, const int64_t *accept_tokens, - const int *accept_num, + int *accept_num, const bool *stop_flags, const int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, const int64_t *step_idx, int bs, int length, int max_draft_tokens) { int tid = threadIdx.x; - if (tid < bs && !stop_flags[tid]) { - int64_t *pre_ids_all_now = pre_ids_all + tid * length; - const int64_t *accept_tokens_now = - accept_tokens + tid * max_draft_tokens; - const int seq_len_dec = seq_lens_decoder[tid]; - const int seq_len_enc = seq_lens_encoder[tid]; - if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped - // printf("step_idx[tid] %d\n", step_idx[tid]); - if (step_idx[tid] >= 0) { - for (int i = 0; i < accept_num[tid]; i++) { - pre_ids_all_now[step_idx[tid] - i] = - accept_tokens_now[accept_num[tid] - 1 - i]; - // printf("pre_ids_all_now[step_idx[tid] - i] %d \n", - // pre_ids_all_now[step_idx[tid] - i]); + + if (tid < bs) { + if (!stop_flags[tid]) { + int64_t *pre_ids_all_now = pre_ids_all + tid * length; + const int64_t *accept_tokens_now = + accept_tokens + tid * max_draft_tokens; + const int seq_len_dec = seq_lens_decoder[tid]; + const int seq_len_enc = seq_lens_encoder[tid]; + if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped + if (step_idx[tid] >= 0) { + for (int i = 0; i < accept_num[tid]; i++) { + pre_ids_all_now[step_idx[tid] - i] = + accept_tokens_now[accept_num[tid] - 1 - i]; + } } + } else { + accept_num[tid] = 0; + seq_lens_decoder[tid] = 0; } } } @@ -67,10 +70,10 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>( const_cast(pre_ids_all.data()), accept_tokens.data(), - accept_num.data(), + const_cast(accept_num.data()), stop_flags.data(), seq_lens_encoder.data(), - seq_lens_decoder.data(), + const_cast(seq_lens_decoder.data()), step_idx.data(), bs, length, @@ -86,6 +89,9 @@ PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx) "seq_lens_encoder", "seq_lens_decoder", "step_idx"}) - .Outputs({"pre_ids_all_out"}) - .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}}) + .Outputs({"pre_ids_all_out", "accept_num_out", "seq_lens_decoder_out"}) + .SetInplaceMap({ + {"pre_ids_all", "pre_ids_all_out"}, + {"accept_num", "accept_num_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}}) .SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu index 48d6557ad..828dc1728 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu @@ -71,9 +71,6 @@ __global__ void speculate_update(int *seq_lens_encoder, } draft_tokens[bid * max_draft_tokens] = accept_tokens[bid * max_draft_tokens + accept_num_now - 1]; - if (stop_flag_now_int) { - seq_lens_decoder[bid] = 0; - } } else if (bid >= real_bsz && bid < max_bsz) { stop_flag_now_int = 1; } diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 27b8eecaf..0091ab182 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -1026,11 +1026,7 @@ class EngineArgs: speculative_cfg = self.create_speculative_config() if not self.enable_chunked_prefill: - if ( - current_platform.is_cuda() - and self.splitwise_role == "mixed" - and (speculative_cfg is None or speculative_cfg.method not in ["mtp"]) - ): + if current_platform.is_cuda() and self.splitwise_role == "mixed": # default enable chunked prefill self.enable_chunked_prefill = True diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 65948ea7d..d695df0a7 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -64,7 +64,6 @@ else: save_output, save_output_topk, set_stop_value_multi_ends, - speculate_clear_accept_nums, speculate_get_output_padding_offset, speculate_get_padding_offset, speculate_get_seq_lens_output, @@ -369,12 +368,13 @@ def post_process_specualate( model_output.accept_tokens, model_output.accept_num, model_output.not_need_stop, + model_output.seq_lens_decoder, + model_output.prompt_lens, model_output.mp_rank, save_each_rank, + envs.ENABLE_V1_KVCACHE_SCHEDULER, ) - speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder) - # Update pre_ids through accept tokens speculate_set_value_by_flags_and_idx( diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8d1be796f..bd39d6efb 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1460,6 +1460,7 @@ class GPUModelRunner(ModelRunnerBase): reasoning_index=self.share_inputs["reasoning_index"], stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + prompt_lens=self.share_inputs["prompt_lens"], ) post_process( @@ -1814,6 +1815,7 @@ class GPUModelRunner(ModelRunnerBase): reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests], stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + prompt_lens=self.share_inputs["prompt_lens"], ) if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": @@ -1860,7 +1862,9 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["draft_tokens"], self.share_inputs["block_tables"], self.share_inputs["stop_flags"], + self.share_inputs["prompt_lens"], self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], self.share_inputs["seq_lens_decoder"], self.share_inputs["step_seq_lens_decoder"], self.share_inputs["step_draft_tokens"], diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 6d820a873..2fa348634 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -250,6 +250,11 @@ class ModelOutputData: """ stop_seqs_len: paddle.Tensor = None + """ + the length of input prompt + """ + prompt_lens: paddle.Tensor = None + @dataclass class ModelRunnerOutput: diff --git a/tests/operators/test_speculate_update.py b/tests/operators/test_speculate_update.py index a9d5770c6..d3dcd7e7f 100644 --- a/tests/operators/test_speculate_update.py +++ b/tests/operators/test_speculate_update.py @@ -70,9 +70,6 @@ def speculate_update_np( draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1] - if stop_flag_now_int: - seq_lens_decoder[bid] = 0 - elif inactive: stop_flag_now_int = 1 diff --git a/tests/operators/test_speculative_schedule_cache.py b/tests/operators/test_speculative_schedule_cache.py index 9c95ad203..50a0e7cab 100644 --- a/tests/operators/test_speculative_schedule_cache.py +++ b/tests/operators/test_speculative_schedule_cache.py @@ -10,7 +10,9 @@ def cpu_reference( draft_tokens, block_tables, stop_flags, + prompt_lens, seq_lens_this_time, + seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder, step_draft_tokens, @@ -101,7 +103,9 @@ class TestSpeculateScheduleCache(unittest.TestCase): self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32)) # stop_flags length is max_bsz, others are real_bsz self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_)) + self.prompt_lens = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int64)) self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32)) + self.seq_lens_encoder = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int32)) self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32)) # Will be filled by kernel for the triggering bids only @@ -129,7 +133,9 @@ class TestSpeculateScheduleCache(unittest.TestCase): self.np_draft_tokens = self.draft_tokens.numpy().copy() self.np_block_tables = self.block_tables.numpy().copy() self.np_stop_flags = self.stop_flags.numpy().copy() + self.np_prompt_lens = self.prompt_lens.numpy().copy() self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy() + self.np_seq_lens_encoder = self.seq_lens_encoder.numpy().copy() self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy() self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy() self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy() @@ -146,7 +152,9 @@ class TestSpeculateScheduleCache(unittest.TestCase): self.draft_tokens, self.block_tables, self.stop_flags, + self.prompt_lens, self.seq_lens_this_time, + self.seq_lens_encoder, self.seq_lens_decoder, self.step_seq_lens_decoder, self.step_draft_tokens, @@ -165,7 +173,9 @@ class TestSpeculateScheduleCache(unittest.TestCase): self.np_draft_tokens, self.np_block_tables, self.np_stop_flags, + self.prompt_lens, self.np_seq_lens_this_time, + self.np_seq_lens_encoder, self.np_seq_lens_decoder, self.np_step_seq_lens_decoder, self.np_step_draft_tokens, @@ -213,7 +223,9 @@ class TestSpeculateScheduleCache(unittest.TestCase): self.draft_tokens, self.block_tables, self.stop_flags, + self.prompt_lens, self.seq_lens_this_time, + self.seq_lens_encoder, self.seq_lens_decoder, self.step_seq_lens_decoder, self.step_draft_tokens,