From e1a9b282eb7d48784d28a6529e528b774634d74b Mon Sep 17 00:00:00 2001
From: lizan1999 <55830407+lizan1999@users.noreply.github.com>
Date: Thu, 18 Dec 2025 14:34:54 +0800
Subject: [PATCH] fix bug for EP+MTP (#5605)

Co-authored-by: lizan1999
---
 custom_ops/xpu_ops/src/ops/adjust_batch.cc      | 22 ++++----
 .../ops/mtp/speculate_get_padding_offset.cc     | 50 ++++++++++---------
 fastdeploy/spec_decode/mtp.py                   |  8 ---
 3 files changed, 38 insertions(+), 42 deletions(-)

diff --git a/custom_ops/xpu_ops/src/ops/adjust_batch.cc b/custom_ops/xpu_ops/src/ops/adjust_batch.cc
index fb3b31688..b33f51ddb 100644
--- a/custom_ops/xpu_ops/src/ops/adjust_batch.cc
+++ b/custom_ops/xpu_ops/src/ops/adjust_batch.cc
@@ -71,16 +71,18 @@ std::vector AdjustBatchKernel(
       const_cast(decoder_batch_idx.data())};
 
   auto out = paddle::empty({token_num, dim}, x.type(), x.place());
-
-  int r = baidu::xpu::api::plugin::eb_adjust_batch(
-      ctx,
-      reinterpret_cast(x.data()),
-      reinterpret_cast(out.data()),
-      encoder_seqs_lods_vp,
-      decoder_seqs_lods_vp,
-      encoder_batch_map_vp,
-      decoder_batch_map_vp,
-      dim);
+  if (token_num > 0) {
+    int r = baidu::xpu::api::plugin::eb_adjust_batch(
+        ctx,
+        reinterpret_cast(x.data()),
+        reinterpret_cast(out.data()),
+        encoder_seqs_lods_vp,
+        decoder_seqs_lods_vp,
+        encoder_batch_map_vp,
+        decoder_batch_map_vp,
+        dim);
+    PD_CHECK(r == 0, "XPU eb_adjust_batch failed");
+  }
 
   return {out};
 }
diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc
index f22dc7aaa..7ebf64ccc 100644
--- a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc
@@ -57,31 +57,33 @@ std::vector SpeculateGetPaddingOffset(
            "Cum offsets tensor must be contiguous");
   PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous");
 
-  int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
-      xpu_ctx->x_context(),
-      batch_id_per_token.data(),
-      cum_offsets_out.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      cum_offsets.data(),
-      seq_len.data(),
-      seq_length,
-      bsz);
-  PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
+  if (token_num_data > 0) {
+    int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
+        xpu_ctx->x_context(),
+        batch_id_per_token.data(),
+        cum_offsets_out.data(),
+        cu_seqlens_q.data(),
+        cu_seqlens_k.data(),
+        cum_offsets.data(),
+        seq_len.data(),
+        seq_length,
+        bsz);
+    PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
 
-  r = baidu::xpu::api::plugin::speculate_remove_padding(
-      xpu_ctx->x_context(),
-      x_remove_padding.data(),
-      input_ids.data(),
-      draft_tokens.data(),
-      seq_len.data(),
-      seq_lens_encoder.data(),
-      cum_offsets_out.data(),
-      seq_length,
-      max_draft_tokens,
-      bsz,
-      token_num_data);
-  PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
+    r = baidu::xpu::api::plugin::speculate_remove_padding(
+        xpu_ctx->x_context(),
+        x_remove_padding.data(),
+        input_ids.data(),
+        draft_tokens.data(),
+        seq_len.data(),
+        seq_lens_encoder.data(),
+        cum_offsets_out.data(),
+        seq_length,
+        max_draft_tokens,
+        bsz,
+        token_num_data);
+    PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
+  }
 
   return {x_remove_padding,
           cum_offsets_out,
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index b373141a9..052bb2a6a 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -732,14 +732,6 @@ class MTPProposer(Proposer):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-        # Mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
-            only_decode_batch_list = []
-            prefill_exists = self.exist_prefill()
-            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
-            only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
-
     def exist_prefill(self):
         """
         check whether prefill stage exist