fix pd-split first step bug (#5246)

This commit is contained in:
freeliuzc
2025-11-26 18:02:32 +08:00
committed by GitHub
parent 710753377f
commit bdcc952eeb
3 changed files with 12 additions and 8 deletions

View File

@@ -24,7 +24,7 @@ __global__ void update_attn_mask_offsets_kernel(
int* attn_mask_offsets_decoder,
const bool* is_block_step,
int* decode_states,
const int* mask_rollback,
int* mask_rollback,
const int real_bsz,
const int max_model_len,
const int decode_states_len) {
@@ -58,7 +58,7 @@ __global__ void update_attn_mask_offsets_kernel(
// Status: decoder -- normal or chunk_prefill
// TODO: support speculative decoding.
attn_mask_offsets_decoder[bid] -= mask_rollback[bid];
mask_rollback[bid] = 0;
for (int i = 0; i < seq_len_this_time; i++) {
attn_mask_offsets[(query_start_id + i) * 2 + 1] =
attn_mask_offsets_decoder[bid] + 1 + i;
@@ -117,7 +117,7 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const_cast<int*>(attn_mask_offsets_decoder.data<int>()),
is_block_step.data<bool>(),
const_cast<int*>(decode_states.data<int>()),
mask_rollback.data<int>(),
const_cast<int*>(mask_rollback.data<int>()),
real_bsz,
max_model_len,
decode_states_len);
@@ -136,6 +136,7 @@ PD_BUILD_STATIC_OP(update_attn_mask_offsets)
"is_block_step",
"decode_states",
"mask_rollback"})
.Outputs({"attn_mask_offsets", "decode_states_out"})
.SetInplaceMap({{"decode_states", "decode_states_out"}})
.Outputs({"attn_mask_offsets", "decode_states_out", "mask_rollback_out"})
.SetInplaceMap({{"decode_states", "decode_states_out"},
{"mask_rollback", "mask_rollback_out"}})
.SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));

View File

@@ -319,9 +319,6 @@ class EngineService:
)
self.cfg.cache_config.cache_queue_port = self.cache_task_queue.get_server_port()
self.llm_logger.info(
f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}"
)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,

View File

@@ -515,6 +515,12 @@ class MTPProposer(Proposer):
self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = (
inputs["attention_mask_offset"][prefill_end_index - 1] + 1
)
if (
self.fd_config.scheduler_config.splitwise_role == "decode"
): # In PD, we continue to decode after P generates first token
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
# P-D split need rollback one step
self.model_inputs["mask_rollback"][idx : idx + 1] = 1
# has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task