[Speculative Decoding][MTP] Support stop_seqs and pd-split mode (#5029)

* support multi_stop_seqs in speculative decoding

* support mtp tp with ep split

* fix custom op register

* fix spec stop_seqs params
This commit is contained in:
freeliuzc
2025-11-20 15:26:01 +08:00
committed by GitHub
parent 3e3558f492
commit f1e36ff2f7
3 changed files with 143 additions and 128 deletions

View File

@@ -17,7 +17,6 @@
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// #define DEBUG_SPEC_STOP_SEQS
__global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
@@ -34,100 +33,101 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int pre_ids_len) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
if (tid >= stop_seqs_bs) return;
const int stop_seq_len = stop_seqs_len[tid];
if (stop_seq_len <= 0) return;
if (bid < bs) {
const int64_t *stop_seq_now = stop_seqs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int accept_num = accept_nums[bid];
const int64_t step_idx_now = step_idx[bid];
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
// 遍历起始位置
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
if (tid >= stop_seqs_bs) return;
const int stop_seq_len = stop_seqs_len[bid * stop_seqs_bs + tid];
if (stop_seq_len <= 0) return;
if (bid < bs) {
const int64_t *stop_seq_now = stop_seqs +
bid * stop_seqs_max_len * stop_seqs_bs +
tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int accept_num = accept_nums[bid];
const int64_t step_idx_now = step_idx[bid];
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
// 遍历起始位置
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("num %d < stop_seq_len %d\n",
step_idx_now - accept_num + accept_idx + 1,
stop_seq_len);
printf("num %d < stop_seq_len %d\n",
step_idx_now - accept_num + accept_idx + 1,
stop_seq_len);
#endif
continue;
}
// 遍历一个 stop_seqs
for (int i = stop_seq_len - 1; i >= 0; --i) {
int64_t cur_token_idx = -1;
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
if (stop_seq_len - 1 - i < accept_idx) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
"accept_token_idx: "
"%d\n",
bid,
tid,
accept_idx,
accept_idx - (stop_seq_len - 1 - i) - 1);
#endif
cur_token_idx =
accept_tokens_now[accept_idx -
(stop_seq_len - 1 - i) - 1];
} else {
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
"accept_idx:%d. "
"pre_id_idx: %ld\n",
bid,
tid,
step_idx_now,
accept_idx,
step_idx_now - accept_num + accept_idx -
(stop_seq_len - 1 - i));
#endif
int pre_ids_idx = step_idx_now - accept_num +
accept_idx - (stop_seq_len - 1 - i);
// EC3
// 特殊拼接会导致input_ids最后一位无特殊token即pre_ids[0]可能为23,
// 导致异常结束
if (pre_ids_idx <= 0) {
break;
}
cur_token_idx = pre_ids_now[pre_ids_idx];
}
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"bid:%d. tid:%d, cur_token_idx: %ld. stop_seq_now "
"%ld\n",
bid,
tid,
cur_token_idx,
stop_seq_now[i]);
#endif
if (cur_token_idx != stop_seq_now[i]) {
break;
}
if (i == 0) {
is_end = true;
}
}
}
if (is_end) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("bid:%d end with accept_idx %d", bid, accept_idx);
#endif
accept_nums[bid] = accept_idx;
accept_tokens_now[accept_idx - 1] = end_ids[0];
stop_flags[bid] = true;
}
continue;
}
// 遍历一个 stop_seqs
for (int i = stop_seq_len - 1; i >= 0; --i) {
int64_t cur_token_idx = -1;
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
if (stop_seq_len - 1 - i < accept_idx) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
"accept_token_idx: "
"%d\n",
bid,
tid,
accept_idx,
accept_idx - (stop_seq_len - 1 - i) - 1);
#endif
cur_token_idx =
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
} else {
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
"accept_idx:%d. "
"pre_id_idx: %ld\n",
bid,
tid,
step_idx_now,
accept_idx,
step_idx_now - accept_num + accept_idx -
(stop_seq_len - 1 - i));
#endif
int pre_ids_idx =
step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i);
// EC3
// 特殊拼接会导致input_ids最后一位无特殊token即pre_ids[0]可能为23,
// 导致异常结束
if (pre_ids_idx <= 0) {
break;
}
cur_token_idx = pre_ids_now[pre_ids_idx];
}
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"bid:%d. tid:%d, cur_token_idx: %ld. stop_seq_now "
"%ld\n",
bid,
tid,
cur_token_idx,
stop_seq_now[i]);
#endif
if (cur_token_idx != stop_seq_now[i]) {
break;
}
if (i == 0) {
is_end = true;
}
}
}
if (is_end) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("bid:%d end with accept_idx %d", bid, accept_idx);
#endif
accept_nums[bid] = accept_idx;
accept_tokens_now[accept_idx - 1] = end_ids[0];
stop_flags[bid] = true;
}
}
}
}
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
@@ -139,34 +139,34 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids) {
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
auto cu_stream = accept_tokens.stream();
std::vector<int64_t> shape = accept_tokens.shape();
std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
int bs_now = shape[0];
int stop_seqs_bs = stop_seqs_shape[0];
int stop_seqs_max_len = stop_seqs_shape[1];
int pre_ids_len = pre_ids.shape()[1];
int accept_tokens_len = accept_tokens.shape()[1];
auto cu_stream = accept_tokens.stream();
std::vector<int64_t> shape = accept_tokens.shape();
std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
int bs_now = shape[0];
int stop_seqs_bs = stop_seqs_shape[1];
int stop_seqs_max_len = stop_seqs_shape[2];
int pre_ids_len = pre_ids.shape()[1];
int accept_tokens_len = accept_tokens.shape()[1];
int block_size = (stop_seqs_bs + 31) / 32 * 32;
spec_set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
seq_lens.data<int>(),
end_ids.data<int64_t>(),
bs_now,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
int block_size = (stop_seqs_bs + 31) / 32 * 32;
spec_set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
seq_lens.data<int>(),
end_ids.data<int64_t>(),
bs_now,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
}
PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)

View File

@@ -247,18 +247,22 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
)
return base_range
prefix_layer_name = (
"mtp_block" if getattr(fd_config.speculative_config, "model_type", "main") == "mtp" else "layers"
)
for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
for j in get_expert_ranges(fd_config):
up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
down_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight"
up_gate_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight"
down_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight"
up_gate_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
down_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight"
up_gate_proj_quant_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
down_proj_quant_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.quant_weight"
up_gate_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
down_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale"
up_gate_proj_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
down_proj_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight_scale"
down_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.activation_scale"
down_proj_in_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.activation_scale"
num_local_ffn_keys.append(up_gate_proj_key)
num_local_ffn_keys.append(down_proj_key)
num_local_ffn_keys.append(up_gate_proj_quant_key)
@@ -273,7 +277,7 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
num_experts = num_experts[0]
for j in range(num_experts):
up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
up_gate_proj_in_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
num_local_ffn_keys.append(up_gate_proj_in_scale_key)
for k in num_local_ffn_keys:
@@ -284,7 +288,7 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
if fd_config.parallel_config.use_sequence_parallel_moe:
for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
k = f"ernie.layers.{i}.self_attn.o_proj.weight"
k = f"ernie.{prefix_layer_name}.{i}.self_attn.o_proj.weight"
if k in weight_list:
no_tp_action_keys.append(k)
tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
@@ -506,7 +510,7 @@ def load_composite_checkpoint(
# 2. Tensor Parallel (TP)
# 3. Pre-sharded (pre-split)
"""
if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
if fd_config.parallel_config.use_ep:
state_dict = load_ep_checkpoint(cls, model_path, fd_config, return_numpy=True)
else:
rank_dirs = [

View File

@@ -79,6 +79,7 @@ else:
speculate_step_paddle,
speculate_step_system_cache,
speculate_update,
speculate_set_stop_value_multi_seqs,
step_paddle,
step_system_cache,
update_inputs,
@@ -467,7 +468,17 @@ def post_process_specualate(
think_end_id=think_end_id,
line_break_id=line_break_id,
)
speculate_set_stop_value_multi_seqs(
model_output.accept_tokens,
model_output.accept_num,
model_output.pre_ids,
model_output.step_idx,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.stop_token_ids,
model_output.stop_seqs_len,
model_output.eos_token_id,
)
speculate_update(
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,