commit e1a9b282eb (parent d8587e987e)
Author: lizan1999
Date:   2025-12-18 14:34:54 +08:00 (committed via GitHub)

    fix bug for EP+MTP (#5605)

    Co-authored-by: lizan1999 <lizan03@baidu.com>

3 changed files with 38 additions and 42 deletions


@@ -71,16 +71,18 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
       const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};
   auto out = paddle::empty({token_num, dim}, x.type(), x.place());
-  int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
-      ctx,
-      reinterpret_cast<const XPUType *>(x.data<data_t>()),
-      reinterpret_cast<XPUType *>(out.data<data_t>()),
-      encoder_seqs_lods_vp,
-      decoder_seqs_lods_vp,
-      encoder_batch_map_vp,
-      decoder_batch_map_vp,
-      dim);
-  PD_CHECK(r == 0, "XPU eb_adjust_batch failed");
+  if (token_num > 0) {
+    int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
+        ctx,
+        reinterpret_cast<const XPUType *>(x.data<data_t>()),
+        reinterpret_cast<XPUType *>(out.data<data_t>()),
+        encoder_seqs_lods_vp,
+        decoder_seqs_lods_vp,
+        encoder_batch_map_vp,
+        decoder_batch_map_vp,
+        dim);
+    PD_CHECK(r == 0, "XPU eb_adjust_batch failed");
+  }
   return {out};
 }
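Both XPU fixes share one shape: when the batch contains zero tokens, skip the plugin kernel launch entirely and return the already-allocated (empty) output, since the kernel presumably does not tolerate zero-size inputs. A minimal Python sketch of the guard pattern, with run_plugin_kernel as a hypothetical stand-in for the real eb_adjust_batch dispatch:

    import paddle

    def adjust_batch(x: paddle.Tensor, token_num: int, dim: int) -> paddle.Tensor:
        out = paddle.empty([token_num, dim], dtype=x.dtype)
        if token_num > 0:  # zero-token batches never reach the kernel
            run_plugin_kernel(x, out)  # hypothetical dispatch; assumed unsafe on empty input
        return out  # a [0, dim] tensor is still a valid result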


@@ -57,31 +57,33 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
            "Cum offsets tensor must be contiguous");
   PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous");
-  int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
-      xpu_ctx->x_context(),
-      batch_id_per_token.data<int>(),
-      cum_offsets_out.data<int>(),
-      cu_seqlens_q.data<int>(),
-      cu_seqlens_k.data<int>(),
-      cum_offsets.data<int>(),
-      seq_len.data<int>(),
-      seq_length,
-      bsz);
-  PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
+  if (token_num_data > 0) {
+    int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
+        xpu_ctx->x_context(),
+        batch_id_per_token.data<int>(),
+        cum_offsets_out.data<int>(),
+        cu_seqlens_q.data<int>(),
+        cu_seqlens_k.data<int>(),
+        cum_offsets.data<int>(),
+        seq_len.data<int>(),
+        seq_length,
+        bsz);
+    PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
-  r = baidu::xpu::api::plugin::speculate_remove_padding<int64_t>(
-      xpu_ctx->x_context(),
-      x_remove_padding.data<int64_t>(),
-      input_ids.data<int64_t>(),
-      draft_tokens.data<int64_t>(),
-      seq_len.data<int>(),
-      seq_lens_encoder.data<int>(),
-      cum_offsets_out.data<int>(),
-      seq_length,
-      max_draft_tokens,
-      bsz,
-      token_num_data);
-  PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
+    r = baidu::xpu::api::plugin::speculate_remove_padding<int64_t>(
+        xpu_ctx->x_context(),
+        x_remove_padding.data<int64_t>(),
+        input_ids.data<int64_t>(),
+        draft_tokens.data<int64_t>(),
+        seq_len.data<int>(),
+        seq_lens_encoder.data<int>(),
+        cum_offsets_out.data<int>(),
+        seq_length,
+        max_draft_tokens,
+        bsz,
+        token_num_data);
+    PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
+  }
   return {x_remove_padding,
           cum_offsets_out,
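For intuition, a rough Python model of what the two guarded kernels compute, inferred from the argument names rather than the exact XPU semantics (draft-token handling omitted): cumulative padding offsets per sequence, then a packed copy of the real tokens with padding stripped. When token_num_data is 0 there is nothing to pack, which is why both launches can be skipped:

    import numpy as np

    def remove_padding_ref(input_ids: np.ndarray, seq_len: np.ndarray) -> np.ndarray:
        # input_ids: padded [bsz, max_seq_len]; seq_len: real length per sequence
        rows = [row[:n] for row, n in zip(input_ids, seq_len)]
        return np.concatenate(rows) if rows else np.empty(0, dtype=input_ids.dtype)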


@@ -732,14 +732,6 @@ class MTPProposer(Proposer):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
-        # Mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
-            only_decode_batch_list = []
-            prefill_exists = self.exist_prefill()
-            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
-            only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"

     def exist_prefill(self):
         """
         check whether prefill stage exist
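For reference, the deleted block implemented a cross-rank consensus: with expert parallelism in the "mixed" splitwise role, every rank gathered a "my batch has no prefill" flag, and the MoE phase switched to decode only when all ranks agreed. A condensed sketch of that removed logic, using the same paddle.distributed call:

    import paddle.distributed as dist

    def pick_moe_phase(prefill_exists: bool) -> str:
        flags = []  # one boolean per EP rank after the collective
        dist.all_gather_object(flags, not prefill_exists)
        # every rank joins the MoE collectives, so all must be decode-only
        return "decode" if all(flags) else "prefill"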