diff --git a/custom_ops/gpu_ops/update_attn_mask_offsets.cu b/custom_ops/gpu_ops/update_attn_mask_offsets.cu
index 3318fd0cf..7d9611c7b 100644
--- a/custom_ops/gpu_ops/update_attn_mask_offsets.cu
+++ b/custom_ops/gpu_ops/update_attn_mask_offsets.cu
@@ -24,7 +24,7 @@ __global__ void update_attn_mask_offsets_kernel(
     int* attn_mask_offsets_decoder,
     const bool* is_block_step,
     int* decode_states,
-    const int* mask_rollback,
+    int* mask_rollback,
     const int real_bsz,
     const int max_model_len,
     const int decode_states_len) {
@@ -58,7 +58,7 @@ __global__ void update_attn_mask_offsets_kernel(
     // Status: decoder -- normal or chunk_prefill
     // TODO: support speculative decoding.
     attn_mask_offsets_decoder[bid] -= mask_rollback[bid];
-
+    mask_rollback[bid] = 0;
     for (int i = 0; i < seq_len_this_time; i++) {
       attn_mask_offsets[(query_start_id + i) * 2 + 1] =
           attn_mask_offsets_decoder[bid] + 1 + i;
@@ -117,7 +117,7 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
       const_cast<int*>(attn_mask_offsets_decoder.data<int>()),
       is_block_step.data<bool>(),
       const_cast<int*>(decode_states.data<int>()),
-      mask_rollback.data<int>(),
+      const_cast<int*>(mask_rollback.data<int>()),
       real_bsz,
       max_model_len,
       decode_states_len);
@@ -136,6 +136,7 @@ PD_BUILD_STATIC_OP(update_attn_mask_offsets)
              "is_block_step",
              "decode_states",
              "mask_rollback"})
-    .Outputs({"attn_mask_offsets", "decode_states_out"})
-    .SetInplaceMap({{"decode_states", "decode_states_out"}})
+    .Outputs({"attn_mask_offsets", "decode_states_out", "mask_rollback_out"})
+    .SetInplaceMap({{"decode_states", "decode_states_out"},
+                    {"mask_rollback", "mask_rollback_out"}})
     .SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 136f19508..aefe43037 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -319,9 +319,6 @@ class EngineService:
             )
             self.cfg.cache_config.cache_queue_port = self.cache_task_queue.get_server_port()
-            self.llm_logger.info(
-                f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}"
-            )

         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 321ef5a1b..611c3ab5f 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -515,6 +515,12 @@ class MTPProposer(Proposer):
                 self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = (
                     inputs["attention_mask_offset"][prefill_end_index - 1] + 1
                 )
+                if (
+                    self.fd_config.scheduler_config.splitwise_role == "decode"
+                ):  # In PD, we continue to decode after P generates first token
+                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+                    # P-D split need rollback one step
+                    self.model_inputs["mask_rollback"][idx : idx + 1] = 1
                 # has_prefill_task = True
             elif request.task_type.value == RequestType.DECODE.value:  # decode task