From f1e36ff2f77e326657452ea952e00e16b8f89310 Mon Sep 17 00:00:00 2001
From: freeliuzc
Date: Thu, 20 Nov 2025 15:26:01 +0800
Subject: [PATCH] [Speculative Decoding][MTP] Support stop_seqs and pd-split
 mode (#5029)

* Support multiple stop sequences (multi_stop_seqs) in speculative decoding

* Support MTP tensor parallelism (TP) with expert-parallel (EP) split

* Fix custom op registration

* Fix speculative-decoding stop_seqs parameters
---
 .../speculate_set_stop_value_multi_seqs.cu    | 234 +++++++++---------
 .../model_executor/load_weight_utils.py       |  24 +-
 .../model_executor/pre_and_post_process.py    |  13 +-
 3 files changed, 143 insertions(+), 128 deletions(-)

diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu
index 44f00c2a9..956beceb5 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu
@@ -17,7 +17,6 @@
 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
-
 // #define DEBUG_SPEC_STOP_SEQS
 
 __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
@@ -34,100 +33,101 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
                                             const int stop_seqs_bs,
                                             const int stop_seqs_max_len,
                                             const int pre_ids_len) {
-    const int bid = blockIdx.x;
-    const int tid = threadIdx.x;
-    if (tid >= stop_seqs_bs) return;
-    const int stop_seq_len = stop_seqs_len[tid];
-    if (stop_seq_len <= 0) return;
-    if (bid < bs) {
-        const int64_t *stop_seq_now = stop_seqs + tid * stop_seqs_max_len;
-        const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
-        int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
-        const int accept_num = accept_nums[bid];
-        const int64_t step_idx_now = step_idx[bid];
-        if (!stop_flags[bid]) {
-            int accept_idx = 0;
-            bool is_end = false;
-            // Iterate over candidate start positions
-            for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
-                if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
+  const int bid = blockIdx.x;
+  const int tid = threadIdx.x;
+  if (tid >= stop_seqs_bs) return;
+  const int stop_seq_len = stop_seqs_len[bid * stop_seqs_bs + tid];
+  if (stop_seq_len <= 0) return;
+  if (bid < bs) {
+    const int64_t *stop_seq_now = stop_seqs +
+                                  bid * stop_seqs_max_len * stop_seqs_bs +
+                                  tid * stop_seqs_max_len;
+    const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
+    int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
+    const int accept_num = accept_nums[bid];
+    const int64_t step_idx_now = step_idx[bid];
+    if (!stop_flags[bid]) {
+      int accept_idx = 0;
+      bool is_end = false;
+      // Iterate over candidate start positions
+      for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
+        if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
 #ifdef DEBUG_SPEC_STOP_SEQS
-                    printf("num %d < stop_seq_len %d\n",
-                           step_idx_now - accept_num + accept_idx + 1,
-                           stop_seq_len);
+          printf("num %d < stop_seq_len %d\n",
+                 step_idx_now - accept_num + accept_idx + 1,
+                 stop_seq_len);
 #endif
-                    continue;
-                }
-                // Walk through one stop sequence
-                for (int i = stop_seq_len - 1; i >= 0; --i) {
-                    int64_t cur_token_idx = -1;
-
-                    // Use the current offset to decide whether the token lies in pre_ids or accept_tokens
-                    if (stop_seq_len - 1 - i < accept_idx) {
-#ifdef DEBUG_SPEC_STOP_SEQS
-                        printf(
-                            "AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
-                            "accept_token_idx: "
-                            "%d\n",
-                            bid,
-                            tid,
-                            accept_idx,
-                            accept_idx - (stop_seq_len - 1 - i) - 1);
-#endif
-                        cur_token_idx =
-                            accept_tokens_now[accept_idx -
-                                              (stop_seq_len - 1 - i) - 1];
-                    } else {
-#ifdef DEBUG_SPEC_STOP_SEQS
-                        printf(
-                            "PreIds bid:%d. tid:%d, step_idx_now:%ld. "
-                            "accept_idx:%d. "
-                            "pre_id_idx: %ld\n",
-                            bid,
-                            tid,
-                            step_idx_now,
-                            accept_idx,
-                            step_idx_now - accept_num + accept_idx -
-                                (stop_seq_len - 1 - i));
-#endif
-                        int pre_ids_idx = step_idx_now - accept_num +
-                                          accept_idx - (stop_seq_len - 1 - i);
-                        // EC3:
-                        // Special concatenation can leave input_ids without a special token in the last position,
-                        // i.e. pre_ids[0] may be 23, causing an abnormal stop
-                        if (pre_ids_idx <= 0) {
-                            break;
-                        }
-                        cur_token_idx = pre_ids_now[pre_ids_idx];
-                    }
-#ifdef DEBUG_SPEC_STOP_SEQS
-                    printf(
-                        "bid:%d. tid:%d, cur_token_idx: %ld. stop_seq_now "
-                        "%ld\n",
-                        bid,
-                        tid,
-                        cur_token_idx,
-                        stop_seq_now[i]);
-#endif
-                    if (cur_token_idx != stop_seq_now[i]) {
-                        break;
-                    }
-                    if (i == 0) {
-                        is_end = true;
-                    }
-                }
-            }
-            if (is_end) {
-#ifdef DEBUG_SPEC_STOP_SEQS
-                printf("bid:%d end with accept_idx %d", bid, accept_idx);
-#endif
-
-                accept_nums[bid] = accept_idx;
-                accept_tokens_now[accept_idx - 1] = end_ids[0];
-                stop_flags[bid] = true;
-            }
-        }
-    }
+          continue;
+        }
+        // Walk through one stop sequence
+        for (int i = stop_seq_len - 1; i >= 0; --i) {
+          int64_t cur_token_idx = -1;
+
+          // Use the current offset to decide whether the token lies in pre_ids or accept_tokens
+          if (stop_seq_len - 1 - i < accept_idx) {
+#ifdef DEBUG_SPEC_STOP_SEQS
+            printf(
+                "AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
+                "accept_token_idx: "
+                "%d\n",
+                bid,
+                tid,
+                accept_idx,
+                accept_idx - (stop_seq_len - 1 - i) - 1);
+#endif
+            cur_token_idx =
+                accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
+          } else {
+#ifdef DEBUG_SPEC_STOP_SEQS
+            printf(
+                "PreIds bid:%d. tid:%d, step_idx_now:%ld. "
+                "accept_idx:%d. "
+                "pre_id_idx: %ld\n",
+                bid,
+                tid,
+                step_idx_now,
+                accept_idx,
+                step_idx_now - accept_num + accept_idx -
+                    (stop_seq_len - 1 - i));
+#endif
+            int pre_ids_idx =
+                step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i);
+            // EC3:
+            // Special concatenation can leave input_ids without a special token in the last position,
+            // i.e. pre_ids[0] may be 23, causing an abnormal stop
+            if (pre_ids_idx <= 0) {
+              break;
+            }
+            cur_token_idx = pre_ids_now[pre_ids_idx];
+          }
+#ifdef DEBUG_SPEC_STOP_SEQS
+          printf(
+              "bid:%d. tid:%d, cur_token_idx: %ld. stop_seq_now "
+              "%ld\n",
+              bid,
+              tid,
+              cur_token_idx,
+              stop_seq_now[i]);
+#endif
+          if (cur_token_idx != stop_seq_now[i]) {
+            break;
+          }
+          if (i == 0) {
+            is_end = true;
+          }
+        }
+      }
+      if (is_end) {
+#ifdef DEBUG_SPEC_STOP_SEQS
+        printf("bid:%d end with accept_idx %d", bid, accept_idx);
+#endif
+
+        accept_nums[bid] = accept_idx;
+        accept_tokens_now[accept_idx - 1] = end_ids[0];
+        stop_flags[bid] = true;
+      }
+    }
+  }
 }
 
 void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
@@ -139,34 +139,34 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
                                const paddle::Tensor &stop_seqs,
                                const paddle::Tensor &stop_seqs_len,
                                const paddle::Tensor &end_ids) {
-    PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
-    PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
+  PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
+  PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
 
-    auto cu_stream = accept_tokens.stream();
-    std::vector<int64_t> shape = accept_tokens.shape();
-    std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
-    int bs_now = shape[0];
-    int stop_seqs_bs = stop_seqs_shape[0];
-    int stop_seqs_max_len = stop_seqs_shape[1];
-    int pre_ids_len = pre_ids.shape()[1];
-    int accept_tokens_len = accept_tokens.shape()[1];
+  auto cu_stream = accept_tokens.stream();
+  std::vector<int64_t> shape = accept_tokens.shape();
+  std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
+  int bs_now = shape[0];
+  int stop_seqs_bs = stop_seqs_shape[1];
+  int stop_seqs_max_len = stop_seqs_shape[2];
+  int pre_ids_len = pre_ids.shape()[1];
+  int accept_tokens_len = accept_tokens.shape()[1];
 
-    int block_size = (stop_seqs_bs + 31) / 32 * 32;
-    spec_set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
-        const_cast<bool *>(stop_flags.data<bool>()),
-        const_cast<int64_t *>(accept_tokens.data<int64_t>()),
-        const_cast<int *>(accept_num.data<int>()),
-        pre_ids.data<int64_t>(),
-        step_idx.data<int64_t>(),
-        stop_seqs.data<int64_t>(),
-        stop_seqs_len.data<int>(),
-        seq_lens.data<int>(),
-        end_ids.data<int64_t>(),
-        bs_now,
-        accept_tokens_len,
-        stop_seqs_bs,
-        stop_seqs_max_len,
-        pre_ids_len);
+  int block_size = (stop_seqs_bs + 31) / 32 * 32;
+  spec_set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
+      const_cast<bool *>(stop_flags.data<bool>()),
+      const_cast<int64_t *>(accept_tokens.data<int64_t>()),
+      const_cast<int *>(accept_num.data<int>()),
+      pre_ids.data<int64_t>(),
+      step_idx.data<int64_t>(),
+      stop_seqs.data<int64_t>(),
+      stop_seqs_len.data<int>(),
+      seq_lens.data<int>(),
+      end_ids.data<int64_t>(),
+      bs_now,
+      accept_tokens_len,
+      stop_seqs_bs,
+      stop_seqs_max_len,
+      pre_ids_len);
 }
 
 PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 408607b10..d2b03f8b1 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -247,18 +247,22 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
         )
         return base_range
 
+    prefix_layer_name = (
+        "mtp_block" if getattr(fd_config.speculative_config, "model_type", "main") == "mtp" else "layers"
+    )
+
     for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
         for j in get_expert_ranges(fd_config):
-            up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
-            down_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight"
+            up_gate_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight"
+            down_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight"
 
-            up_gate_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
-            down_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight"
+            up_gate_proj_quant_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
+            down_proj_quant_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.quant_weight"
 
-            up_gate_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
-            down_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale"
+            up_gate_proj_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
+            down_proj_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight_scale"
 
-            down_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.activation_scale"
+            down_proj_in_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.activation_scale"
             num_local_ffn_keys.append(up_gate_proj_key)
             num_local_ffn_keys.append(down_proj_key)
             num_local_ffn_keys.append(up_gate_proj_quant_key)
@@ -273,7 +277,7 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
             num_experts = num_experts[0]
 
         for j in range(num_experts):
-            up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
+            up_gate_proj_in_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
             num_local_ffn_keys.append(up_gate_proj_in_scale_key)
 
     for k in num_local_ffn_keys:
@@ -284,7 +288,7 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
     no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
     if fd_config.parallel_config.use_sequence_parallel_moe:
         for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
-            k = f"ernie.layers.{i}.self_attn.o_proj.weight"
+            k = f"ernie.{prefix_layer_name}.{i}.self_attn.o_proj.weight"
             if k in weight_list:
                 no_tp_action_keys.append(k)
     tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
@@ -506,7 +510,7 @@ def load_composite_checkpoint(
     # 2. Tensor Parallel (TP)
     # 3. Pre-sharded (pre-split)
     """
-    if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
+    if fd_config.parallel_config.use_ep:
         state_dict = load_ep_checkpoint(cls, model_path, fd_config, return_numpy=True)
     else:
         rank_dirs = [
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index c9c8b27b4..d2c82e2af 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -79,6 +79,7 @@ else:
         speculate_step_paddle,
         speculate_step_system_cache,
         speculate_update,
+        speculate_set_stop_value_multi_seqs,
         step_paddle,
         step_system_cache,
         update_inputs,
@@ -467,7 +468,17 @@ def post_process_specualate(
         think_end_id=think_end_id,
         line_break_id=line_break_id,
     )
-
+    speculate_set_stop_value_multi_seqs(
+        model_output.accept_tokens,
+        model_output.accept_num,
+        model_output.pre_ids,
+        model_output.step_idx,
+        model_output.stop_flags,
+        model_output.seq_lens_this_time,
+        model_output.stop_token_ids,
+        model_output.stop_seqs_len,
+        model_output.eos_token_id,
+    )
     speculate_update(
         model_output.seq_lens_encoder,
         model_output.seq_lens_decoder,