diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 0baed24ac..36b193d0c 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -782,7 +782,8 @@ void SpeculateUpdate(const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& is_block_step, - const paddle::Tensor& stop_nums); + const paddle::Tensor& stop_nums, + const paddle::Tensor& mask_rollback); void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all, const paddle::Tensor& accept_tokens, @@ -1047,6 +1048,18 @@ void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& accept_num); +std::vector UpdateAttnMaskOffsets( + const paddle::Tensor& ids_remove_padding, + const paddle::Tensor& seq_lens_this_time, // only on cpu + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& attn_mask_offsets_full, + const paddle::Tensor& attn_mask_offsets_decoder, + const paddle::Tensor& is_block_step, + const paddle::Tensor& decode_states, + const paddle::Tensor& mask_rollback); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_expert_token_num", &GetExpertTokenNum, @@ -1632,4 +1645,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function"); + + m.def("update_attn_mask_offsets", + &UpdateAttnMaskOffsets, + "update attention mask"); } diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu index 828dc1728..5c92b7e07 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu @@ -16,115 +16,116 @@ template __global__ void speculate_update(int *seq_lens_encoder, - int 
*seq_lens_decoder, - bool *not_need_stop, - int64_t *draft_tokens, - int *actual_draft_token_nums, - const int64_t *accept_tokens, - const int *accept_num, - const bool *stop_flags, - const int *seq_lens_this_time, - const bool *is_block_step, - const int64_t *stop_nums, - const int real_bsz, - const int max_bsz, - const int max_draft_tokens) { - const int bid = threadIdx.x; - const int accept_num_now = accept_num[bid]; - int stop_flag_now_int = 0; - if (!(is_block_step[bid] || bid >= real_bsz)) { - if (stop_flags[bid]) { - stop_flag_now_int = 1; - } - if (seq_lens_encoder[bid] == 0) { - seq_lens_decoder[bid] += accept_num_now; - } - - if (seq_lens_this_time[bid] > 1 && - seq_lens_encoder[bid] == - 0) { // 对于append模式,需要根据接收与否确定是否要降低下次draft - // token的数量 - auto current_actual_draft_token_num = actual_draft_token_nums[bid]; - if (accept_num_now - 1 == current_actual_draft_token_num) { - if (current_actual_draft_token_num + 2 <= - max_draft_tokens - 1) { - actual_draft_token_nums[bid] = - current_actual_draft_token_num + 2; - } else if (current_actual_draft_token_num + 1 <= - max_draft_tokens - 1) { - actual_draft_token_nums[bid] = - current_actual_draft_token_num + 1; - } else { - actual_draft_token_nums[bid] = max_draft_tokens - 1; - } - } else { - actual_draft_token_nums[bid] = - actual_draft_token_nums[bid] - 1 >= 1 - ? 
actual_draft_token_nums[bid] - 1 - : 1; - } - } - - if (seq_lens_encoder[bid] != 0) { - seq_lens_decoder[bid] += seq_lens_encoder[bid]; - seq_lens_encoder[bid] = 0; - } - draft_tokens[bid * max_draft_tokens] = - accept_tokens[bid * max_draft_tokens + accept_num_now - 1]; - } else if (bid >= real_bsz && bid < max_bsz) { - stop_flag_now_int = 1; + int *seq_lens_decoder, + bool *not_need_stop, + int64_t *draft_tokens, + int *actual_draft_token_nums, + const int64_t *accept_tokens, + const int *accept_num, + const bool *stop_flags, + const int *seq_lens_this_time, + const bool *is_block_step, + const int64_t *stop_nums, + int *mask_rollback, + const int real_bsz, + const int max_bsz, + const int max_draft_tokens) { + const int bid = threadIdx.x; + const int accept_num_now = accept_num[bid]; + int stop_flag_now_int = 0; + if (!(is_block_step[bid] || bid >= real_bsz)) { + if (stop_flags[bid]) { + stop_flag_now_int = 1; + mask_rollback[bid] = 0; + } else if (seq_lens_encoder[bid] == 0) { // decoder + seq_lens_decoder[bid] += accept_num_now; + mask_rollback[bid] = seq_lens_this_time[bid] - accept_num_now; + } else { // encoder + mask_rollback[bid] = 0; } - __syncthreads(); - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - // printf("stop_flag_now_int %d \n", stop_flag_now_int); - int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); - - if (threadIdx.x == 0) { - // printf("stop_sum %d \n", stop_sum); - not_need_stop[0] = stop_sum < stop_nums[0]; + if (seq_lens_this_time[bid] > 1 && + seq_lens_encoder[bid] == + 0) { // 对于append模式,需要根据接收与否确定是否要降低下次draft + // token的数量 + auto current_actual_draft_token_num = actual_draft_token_nums[bid]; + if (accept_num_now - 1 == current_actual_draft_token_num) { + if (current_actual_draft_token_num + 2 <= max_draft_tokens - 1) { + actual_draft_token_nums[bid] = current_actual_draft_token_num + 2; + } else if (current_actual_draft_token_num + 1 <= max_draft_tokens - 1) { + 
actual_draft_token_nums[bid] = current_actual_draft_token_num + 1; + } else { + actual_draft_token_nums[bid] = max_draft_tokens - 1; + } + } else { + actual_draft_token_nums[bid] = actual_draft_token_nums[bid] - 1 >= 1 + ? actual_draft_token_nums[bid] - 1 + : 1; + } } + + if (seq_lens_encoder[bid] != 0) { + seq_lens_decoder[bid] += seq_lens_encoder[bid]; + seq_lens_encoder[bid] = 0; + } + draft_tokens[bid * max_draft_tokens] = + accept_tokens[bid * max_draft_tokens + accept_num_now - 1]; + } else if (bid >= real_bsz && bid < max_bsz) { + stop_flag_now_int = 1; + } + __syncthreads(); + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // printf("stop_flag_now_int %d \n", stop_flag_now_int); + int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); + + if (threadIdx.x == 0) { + // printf("stop_sum %d \n", stop_sum); + not_need_stop[0] = stop_sum < stop_nums[0]; + } } void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor ¬_need_stop, - const paddle::Tensor &draft_tokens, - const paddle::Tensor &actual_draft_token_nums, - const paddle::Tensor &accept_tokens, - const paddle::Tensor &accept_num, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &is_block_step, - const paddle::Tensor &stop_nums) { - const int real_bsz = seq_lens_this_time.shape()[0]; - const int max_bsz = stop_flags.shape()[0]; - auto max_draft_tokens = draft_tokens.shape()[1]; + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor ¬_need_stop, + const paddle::Tensor &draft_tokens, + const paddle::Tensor &actual_draft_token_nums, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &accept_num, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &is_block_step, + const paddle::Tensor &stop_nums, + const paddle::Tensor &mask_rollback) { + const int 
real_bsz = seq_lens_this_time.shape()[0]; + const int max_bsz = stop_flags.shape()[0]; + auto max_draft_tokens = draft_tokens.shape()[1]; - constexpr int BlockSize = 512; + constexpr int BlockSize = 512; - auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); - speculate_update<<<1, BlockSize, 0, accept_tokens.stream()>>>( - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(not_need_stop_gpu.data()), - const_cast(draft_tokens.data()), - const_cast(actual_draft_token_nums.data()), - accept_tokens.data(), - accept_num.data(), - stop_flags.data(), - seq_lens_this_time.data(), - is_block_step.data(), - stop_nums.data(), - real_bsz, - max_bsz, - max_draft_tokens); + auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); + speculate_update<<<1, BlockSize, 0, accept_tokens.stream()>>>( + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(not_need_stop_gpu.data()), + const_cast(draft_tokens.data()), + const_cast(actual_draft_token_nums.data()), + accept_tokens.data(), + accept_num.data(), + stop_flags.data(), + seq_lens_this_time.data(), + is_block_step.data(), + stop_nums.data(), + const_cast(mask_rollback.data()), + real_bsz, + max_bsz, + max_draft_tokens); - auto not_need_stop_cpu = - not_need_stop_gpu.copy_to(not_need_stop.place(), true); - bool *not_need_stop_data = const_cast(not_need_stop.data()); - not_need_stop_data[0] = not_need_stop_cpu.data()[0]; + auto not_need_stop_cpu = + not_need_stop_gpu.copy_to(not_need_stop.place(), true); + bool *not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; } PD_BUILD_STATIC_OP(speculate_update) @@ -138,15 +139,18 @@ PD_BUILD_STATIC_OP(speculate_update) "stop_flags", "seq_lens_this_time", "is_block_step", - "stop_nums"}) + "stop_nums", + "mask_rollback"}) .Outputs({"seq_lens_encoder_out", "seq_lens_decoder_out", "not_need_stop_out", "draft_tokens_out", - 
"actual_draft_token_nums_out"}) + "actual_draft_token_nums_out", + "mask_rollback_out"}) .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"}, {"seq_lens_decoder", "seq_lens_decoder_out"}, {"not_need_stop", "not_need_stop_out"}, {"draft_tokens", "draft_tokens_out"}, - {"actual_draft_token_nums", "actual_draft_token_nums_out"}}) + {"actual_draft_token_nums", "actual_draft_token_nums_out"}, + {"mask_rollback", "mask_rollback_out"}}) .SetKernelFn(PD_KERNEL(SpeculateUpdate)); diff --git a/custom_ops/gpu_ops/update_attn_mask_offsets.cu b/custom_ops/gpu_ops/update_attn_mask_offsets.cu new file mode 100644 index 000000000..3318fd0cf --- /dev/null +++ b/custom_ops/gpu_ops/update_attn_mask_offsets.cu @@ -0,0 +1,141 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "helper.h" + +__global__ void update_attn_mask_offsets_kernel( + int* attn_mask_offsets, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int* cu_seqlens_q, + const int* attn_mask_offsets_full, + int* attn_mask_offsets_decoder, + const bool* is_block_step, + int* decode_states, + const int* mask_rollback, + const int real_bsz, + const int max_model_len, + const int decode_states_len) { + int tid = threadIdx.x; + + for (int bid = tid; bid < real_bsz; bid += blockDim.x) { + int seq_len_this_time = seq_lens_this_time[bid]; + int seq_len_encoder = seq_lens_encoder[bid]; + int seq_len_decoder = seq_lens_decoder[bid]; + int query_start_id = cu_seqlens_q[bid]; + + const int* attn_mask_offsets_full_now = + attn_mask_offsets_full + bid * max_model_len; + int* decode_states_now = decode_states + bid * decode_states_len; + if (!is_block_step[bid]) { + if (seq_len_encoder == 0 && seq_len_decoder == 0) { + // Status: stop + } else if (seq_len_encoder > 0) { + for (int i = 0; i < seq_len_this_time; i++) { + if (*decode_states_now == 2 && seq_len_decoder > 0) { + // Status: vision generate phase + attn_mask_offsets[(query_start_id + i) * 2 + 1] = + seq_len_decoder + seq_len_this_time; + } else { + // Status: prefill -- normal or chunk_prefill + attn_mask_offsets[(query_start_id + i) * 2 + 1] = + attn_mask_offsets_full_now[i] + 1; + } + } + } else if (seq_len_decoder > 0) { + // Status: decoder -- normal or chunk_prefill + // TODO: support speculative decoding. 
+ attn_mask_offsets_decoder[bid] -= mask_rollback[bid]; + + for (int i = 0; i < seq_len_this_time; i++) { + attn_mask_offsets[(query_start_id + i) * 2 + 1] = + attn_mask_offsets_decoder[bid] + 1 + i; + } + attn_mask_offsets_decoder[bid] += seq_len_this_time; + + // Speculative decoding in text_generation + if (seq_len_this_time > 1) { + for (int i = 0; i < decode_states_len; i++) { + if (i < seq_len_this_time) { + decode_states_now[i] = 0; + } else { + decode_states_now[i] = -1; + } + } + } + } + } + } +} + +std::vector UpdateAttnMaskOffsets( + const paddle::Tensor& ids_remove_padding, + const paddle::Tensor& seq_lens_this_time, // only on cpu + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& attn_mask_offsets_full, + const paddle::Tensor& attn_mask_offsets_decoder, + const paddle::Tensor& is_block_step, + const paddle::Tensor& decode_states, + const paddle::Tensor& mask_rollback) { + int max_model_len = attn_mask_offsets_full.shape()[1]; + int real_bsz = seq_lens_this_time.shape()[0]; + int batch_seq_lens = ids_remove_padding.shape()[0]; + int decode_states_len = decode_states.shape()[1]; + + auto attn_mask_offsets = paddle::full({batch_seq_lens * 2}, + 0, + paddle::DataType::INT32, + ids_remove_padding.place()); + + // launch config + int blockSize = 512; + + update_attn_mask_offsets_kernel<<<1, + blockSize, + 0, + ids_remove_padding.stream()>>>( + attn_mask_offsets.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + cu_seqlens_q.data(), + attn_mask_offsets_full.data(), + const_cast(attn_mask_offsets_decoder.data()), + is_block_step.data(), + const_cast(decode_states.data()), + mask_rollback.data(), + real_bsz, + max_model_len, + decode_states_len); + + return {attn_mask_offsets}; +} + +PD_BUILD_STATIC_OP(update_attn_mask_offsets) + .Inputs({"ids_remove_padding", + "seq_lens_this_time", + "seq_lens_encoder", + 
"seq_lens_decoder", + "cu_seqlens_q", + "attn_mask_offsets_full", + "attn_mask_offsets_decoder", + "is_block_step", + "decode_states", + "mask_rollback"}) + .Outputs({"attn_mask_offsets", "decode_states_out"}) + .SetInplaceMap({{"decode_states", "decode_states_out"}}) + .SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 81f86b543..78c3b739c 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -305,6 +305,7 @@ elif paddle.is_compiled_with_cuda(): "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length_v1.cu", "gpu_ops/limit_thinking_content_length_v2.cu", + "gpu_ops/update_attn_mask_offsets.cu", ] # pd_disaggregation diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 813fb4790..138e9fcf1 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -96,7 +96,6 @@ class AppendAttentionBackend(AttentionBackend): self.rope_3d = False self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method: str = fd_config.speculative_config.method - self.use_speculate: bool = self.speculative_method is not None self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) @@ -366,7 +365,7 @@ class AppendAttentionBackend(AttentionBackend): getattr(layer, "cache_v_zp", None), layer.linear_shift, layer.linear_smooth, - None if self.use_speculate else forward_meta.attn_mask_offsets, + None, metadata.kv_signal_data_list[layer.layer_id], getattr(layer, "q_norm_weight", None), getattr(layer, "k_norm_weight", None), @@ -385,7 +384,7 @@ class 
AppendAttentionBackend(AttentionBackend): metadata.max_partition_size, metadata.encoder_max_partition_size, self.speculate_max_draft_token_num + 1, - self.causal or self.use_speculate, + True, self.speculative_method is not None, sliding_window, ) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 76f648f39..69748eba5 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -443,6 +443,7 @@ def post_process_specualate( model_output.seq_lens_this_time, model_output.is_block_step, model_output.stop_nums, + model_output.mask_rollback, ) if not skip_save_output: diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index 7438a0dbe..b7611f686 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -69,6 +69,8 @@ class Proposer(ABC): self.max_ngram_size = self.speculative_config.max_ngram_size self.min_ngram_size = self.speculative_config.min_ngram_size + self.enable_mm = self.model_config.enable_mm + spec_logger.info(f"Speculate config: {self.speculative_config}") def run(self, *args, **kwargs) -> Any: diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 50d57ed76..22134cfe8 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -46,6 +46,7 @@ from fastdeploy.model_executor.ops.gpu import ( share_external_data, speculate_get_logits, speculate_save_output_topk, + update_attn_mask_offsets, ) from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding @@ -441,6 +442,21 @@ class MTPProposer(Proposer): self.model_inputs["cu_next_token_offset"] = paddle.full( shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32" ) + self.model_inputs["mask_rollback"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32") + # attn_mask + if self.enable_mm: + self.model_inputs["attn_mask_offsets"] = paddle.full( + 
shape=[self.max_num_seqs * self.max_model_len], fill_value=-1, dtype="int32" + ) + self.model_inputs["attn_mask_offsets_full"] = paddle.full( + [self.max_num_seqs, self.max_model_len], -1, dtype="int32" + ) + self.model_inputs["attn_mask_offsets_decoder"] = paddle.full([self.max_num_seqs, 1], -1, dtype="int32") + self.model_inputs["decode_states"] = paddle.full( + [self.max_num_seqs, self.max_draft_token_num + 1], + -1, + dtype="int32", + ) def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): @@ -482,6 +498,16 @@ class MTPProposer(Proposer): self.model_inputs["step_idx"][idx : idx + 1] = ( len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 ) + if self.enable_mm: + inputs = request.multimodal_inputs + self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = ( + paddle.to_tensor( + inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32" + ) + ) + self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = ( + inputs["attention_mask_offset"][prefill_end_index - 1] + 1 + ) # has_prefill_task = True elif request.task_type.value == RequestType.DECODE.value: # decode task @@ -621,6 +647,7 @@ class MTPProposer(Proposer): kv_batch_ids=self.model_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"], + attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.enable_mm else None, ) # Initialzie attention meta data @@ -754,6 +781,21 @@ class MTPProposer(Proposer): self.model_inputs["seq_lens_decoder"], ) + if self.enable_mm: + attn_mask_offsets = update_attn_mask_offsets( + ids_remove_padding, + getattr(self.model_inputs, "seq_lens_this_time", self.seq_lens_this_time_buffer), + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + cu_seqlens_q, + self.model_inputs["attn_mask_offsets_full"], + 
self.model_inputs["attn_mask_offsets_decoder"], + self.model_inputs["is_block_step"], + self.model_inputs["decode_states"], + self.model_inputs["mask_rollback"], + )[0] + self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) + # Initialize forward meta data self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) self.model_inputs["batch_id_per_token"][:] = -1 diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index ef22a75a8..2a6aa553e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1208,6 +1208,8 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["logits_processors_args"] = [{} for _ in range(max_num_seqs)] logger.info(f"Enabled logits processors: {self.share_inputs['logits_processors']}") + self.share_inputs["mask_rollback"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + def _prepare_inputs(self) -> None: """Prepare the model inputs""" if envs.ENABLE_V1_KVCACHE_SCHEDULER: @@ -1713,6 +1715,8 @@ class GPUModelRunner(ModelRunnerBase): accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + prompt_lens=self.share_inputs["prompt_lens"], + mask_rollback=self.share_inputs["mask_rollback"], ) post_process( @@ -2223,6 +2227,7 @@ class GPUModelRunner(ModelRunnerBase): stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], prompt_lens=self.share_inputs["prompt_lens"], + mask_rollback=self.share_inputs["mask_rollback"], ) if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index b4192e882..7d3a006ca 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -237,6 +237,11 @@ class ModelOutputData: """ prompt_lens: 
paddle.Tensor = None + """ + step mask rollback in some cases + """ + mask_rollback: paddle.Tensor = None + @dataclass class ModelRunnerOutput: diff --git a/tests/operators/test_speculate_update.py b/tests/operators/test_speculate_update.py index d3dcd7e7f..7ea446d80 100644 --- a/tests/operators/test_speculate_update.py +++ b/tests/operators/test_speculate_update.py @@ -32,6 +32,7 @@ def speculate_update_np( seq_lens_this_time, is_block_step, stop_nums, + mask_rollback, ): stop_sum = 0 real_bsz = seq_lens_this_time.shape[0] @@ -47,9 +48,13 @@ def speculate_update_np( if stop_flags[bid]: stop_flag_now_int = 1 + mask_rollback[bid] = 0 - if seq_lens_encoder[bid] == 0: + elif seq_lens_encoder[bid] == 0: seq_lens_decoder[bid] += accept_num[bid] + mask_rollback[bid] = seq_lens_this_time[bid] - accept_num[bid] + else: + mask_rollback[bid] = 0 if (seq_lens_encoder[bid] == 0) and (seq_lens_this_time[bid] > 1): cur_len = actual_draft_token_nums[bid] @@ -103,6 +108,7 @@ def gen_inputs( stop_flags = rng.integers(0, 2, size=max_bsz, dtype=np.bool_) is_block_step = rng.integers(0, 2, size=max_bsz, dtype=np.bool_) stop_nums = np.array([5], dtype=np.int64) + mask_rollback = np.zeros([max_bsz], dtype=np.int32) seq_lens_this_time = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32) @@ -118,6 +124,7 @@ def gen_inputs( "seq_lens_this_time": seq_lens_this_time, "is_block_step": is_block_step, "stop_nums": stop_nums, + "mask_rollback": mask_rollback, } diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index 650f9357c..b3df9750a 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -140,7 +140,9 @@ class TestTreeMask(unittest.TestCase): .reshape([-1, self.num_q_head, self.head_dim]) ) - def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False): + def run_append_c16_attention( + self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False, mask_offset=None + 
): if prefill: seq_lens_enc = [ q_len, @@ -274,7 +276,7 @@ class TestTreeMask(unittest.TestCase): None, # cache_v_zp None, # linear_shift None, # linear_smooth - None, # mask_offset + mask_offset, # mask_offset None, # kv_signal_data self.q_norm_weight_tensor if use_qknorm else None, # q_norm_weight self.k_norm_weight_tensor if use_qknorm else None, # k_norm_weight @@ -293,7 +295,7 @@ class TestTreeMask(unittest.TestCase): self.max_partition_size, self.encoder_max_partition_size, decoder_step_token_num, - True, + True if mask_offset is None else False, decoder_step_token_num > 1, 0, ) @@ -365,6 +367,30 @@ class TestTreeMask(unittest.TestCase): ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 ) + def test_mask_offset(self): + prefill_len = 8192 + dec_len_q = 5 + total_len = prefill_len + dec_len_q + mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) + mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) + self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm) + + mask_offset = paddle.tile( + paddle.tensor( + [0, prefill_len + 1, 0, prefill_len + 2, 0, prefill_len + 3, 0, prefill_len + 4, 0, prefill_len + 5], + dtype="int32", + ), + [self.bsz], + ).astype("int32") + dec_out = self.run_append_c16_attention( + dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm, mask_offset=mask_offset + ) + + ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm) + np.testing.assert_allclose( + ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/operators/test_update_attn_mask.py b/tests/operators/test_update_attn_mask.py new file mode 100644 index 000000000..e6abca69a --- /dev/null +++ b/tests/operators/test_update_attn_mask.py @@ -0,0 +1,277 
@@ +import os +import unittest + +import numpy as np +import paddle + +# Ensure the compiled op is importable from this path +from fastdeploy.model_executor.ops.gpu import update_attn_mask_offsets + + +def py_update_attn_mask_offsets_op( + ids_remove_padding_len, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + cu_seqlens_q, + attn_mask_offsets_full, + attn_mask_offsets_decoder, + is_block_step, + decode_states, + mask_rollback, +): + """ + Python-side reference op that mirrors the CUDA kernel implementation. + - ids_remove_padding_len: total token count after padding removal (used to compute batch_seq_lens) + - seq_lens_*: 1D numpy int32 arrays (len == bsz) + - cu_seqlens_q: 1D numpy int32 prefix sums (len == bsz) + - attn_mask_offsets_full: numpy array shape (bsz, max_model_len) + - attn_mask_offsets_decoder: 1D numpy int32 (bsz,) + - is_block_step: 1D bool array (bsz,) + - decode_states: numpy int32 array shape (bsz, decode_states_len) + - mask_rollback: 1D numpy int32 (bsz,) or shape (bsz,1) + Returns: + attn_mask_offsets_ref (1D int32 length batch_seq_lens * 2), + decode_states_ref (bsz x decode_states_len int32) + """ + # normalize inputs + seq_lens_this_time = np.array(seq_lens_this_time, dtype=np.int32).reshape(-1) + seq_lens_encoder = np.array(seq_lens_encoder, dtype=np.int32).reshape(-1) + seq_lens_decoder = np.array(seq_lens_decoder, dtype=np.int32).reshape(-1) + cu_seqlens_q = np.array(cu_seqlens_q, dtype=np.int32).reshape(-1) + is_block_step = np.array(is_block_step, dtype=bool).reshape(-1) + attn_mask_offsets_full = np.array(attn_mask_offsets_full, dtype=np.int32) + attn_mask_offsets_decoder = np.array(attn_mask_offsets_decoder, dtype=np.int32).reshape(-1) + decode_states = np.array(decode_states, dtype=np.int32).copy() + mask_rollback = np.array(mask_rollback, dtype=np.int32).reshape(-1) + + bsz = int(seq_lens_this_time.shape[0]) + total_seq = int(np.sum(seq_lens_this_time)) + decode_states_len = int(decode_states.shape[1]) + + # CUDA creates paddle::full({batch_seq_lens * 2}, 0) +
attn_mask_offsets = np.zeros((total_seq * 2,), dtype=np.int32) + + for bid in range(bsz): + if is_block_step[bid]: + # skip update for this batch entry + continue + + seq_len_this = int(seq_lens_this_time[bid]) + seq_len_enc = int(seq_lens_encoder[bid]) + seq_len_dec = int(seq_lens_decoder[bid]) + query_start = int(cu_seqlens_q[bid]) + # pointer-like views in C++: attn_mask_offsets_full_now, decode_states_now + full_now = attn_mask_offsets_full[bid] + decode_now = decode_states[bid] # this is a view into decode_states + + # stop: both zero => do nothing + if seq_len_enc == 0 and seq_len_dec == 0: + continue + + # prefill path (encoder > 0) + if seq_len_enc > 0: + for i in range(seq_len_this): + # vision generate phase check: (*decode_states_now == 2 && seq_len_decoder > 0) + # In C++ code they used '*decode_states_now == 2' — meaning first element compare. + if decode_now.size > 0 and decode_now[0] == 2 and seq_len_dec > 0: + attn_mask_offsets[(query_start + i) * 2 + 1] = seq_len_dec + seq_len_this + else: + # attn_mask_offsets_full_now[i] + 1 + attn_mask_offsets[(query_start + i) * 2 + 1] = int(full_now[i]) + 1 + # done prefill branch + continue + + # decoder path (seq_len_decoder > 0) + if seq_len_dec > 0: + # subtract mask rollback + rollback = int(mask_rollback[bid]) if bid < mask_rollback.shape[0] else 0 + attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) - rollback + start = int(attn_mask_offsets_decoder[bid]) + + for i in range(seq_len_this): + attn_mask_offsets[(query_start + i) * 2 + 1] = start + 1 + i + + # advance decoder offset + attn_mask_offsets_decoder[bid] = int(attn_mask_offsets_decoder[bid]) + seq_len_this + + # speculative decoding: if seq_len_this > 1 then set decode_states_now[i] accordingly + if seq_len_this > 1: + for i in range(decode_states_len): + decode_now[i] = 0 if i < seq_len_this else -1 + # done decoder branch + continue + + return attn_mask_offsets, decode_states + + +class 
UpdateAttnMaskOffsetsTestCase(unittest.TestCase): + def setUp(self): + # If GPU available, use it. But we don't hard require CUDA here; op itself must be callable. + # Ensure Paddle uses GPU if available to match operator placement + try: + paddle.set_device("gpu") + except Exception: + paddle.set_device("cpu") + + def _call_and_compare( + self, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + is_block_step, + max_model_len=8, + decode_states_len=4, + vision_generate=False, + ): + # build numpy inputs + seq_lens_this_time = np.array(seq_lens_this_time, dtype=np.int32).reshape(-1) + seq_lens_encoder = np.array(seq_lens_encoder, dtype=np.int32).reshape(-1) + seq_lens_decoder = np.array(seq_lens_decoder, dtype=np.int32).reshape(-1) + bsz = seq_lens_this_time.shape[0] + total_seq = int(np.sum(seq_lens_this_time)) + cu_seqlens_q = np.zeros((bsz,), dtype=np.int32) + if bsz > 1: + cu_seqlens_q[1:] = np.cumsum(seq_lens_this_time[:-1]) + + # attn_mask_offsets_full: shape (bsz, max_model_len) + attn_mask_offsets_full = np.arange(bsz * max_model_len, dtype=np.int32).reshape(bsz, max_model_len) + + # attn_mask_offsets_decoder initial (use seq_lens_decoder as seed for deterministic test) + attn_mask_offsets_decoder = np.array(seq_lens_decoder, dtype=np.int32).copy() + + # decode_states initial + decode_states = np.full((bsz, decode_states_len), -1, dtype=np.int32) + if vision_generate: + decode_states[:, 0] = 2 # make first element 2 to trigger vision phase + + mask_rollback = np.zeros((bsz,), dtype=np.int32) + + # ids_remove_padding: length = total_seq (only length used by op) + ids_remove_padding = paddle.randint(low=0, high=10, shape=[total_seq], dtype="int32") + decode_states_tensor = paddle.to_tensor(decode_states, dtype="int32") + # prepare paddle tensors and call the compiled op + out = update_attn_mask_offsets( + ids_remove_padding, + paddle.to_tensor(seq_lens_this_time, dtype="int32"), + paddle.to_tensor(seq_lens_encoder, dtype="int32"), + 
paddle.to_tensor(seq_lens_decoder, dtype="int32"), + paddle.to_tensor(cu_seqlens_q, dtype="int32"), + paddle.to_tensor(attn_mask_offsets_full, dtype="int32"), + paddle.to_tensor(attn_mask_offsets_decoder, dtype="int32"), + paddle.to_tensor(np.array(is_block_step, dtype=bool).reshape(-1), dtype="bool"), + decode_states_tensor, + paddle.to_tensor(mask_rollback, dtype="int32"), + ) + + # op returns [attn_mask_offsets, decode_states_out] per your PD_BUILD_STATIC_OP outputs + if isinstance(out, (list, tuple)): + op_attn_mask_offsets = out[0].numpy().astype(np.int32).reshape(-1) + op_decode_states = out[1].numpy().astype(np.int32) + else: + # Some bindings might return single tensor and inplace decode_states update + # Try to handle that case: assume attn_mask_offsets returned and decode_states was mutated inplace. + op_attn_mask_offsets = out.numpy().astype(np.int32).reshape(-1) + # fetch decode_states by re-creating input decode_states tensor? best effort: + # (we passed decode_states as a paddle tensor; in operator we passed a copy, but PD set inplace mapping + # so many builds will actually give decode_states_out as second output; this block is fallback.) 
+ op_decode_states = decode_states_tensor.numpy() + + # compute python reference outputs + ref_attn_mask_offsets, ref_decode_states = py_update_attn_mask_offsets_op( + ids_remove_padding_len=total_seq, + seq_lens_this_time=seq_lens_this_time, + seq_lens_encoder=seq_lens_encoder, + seq_lens_decoder=seq_lens_decoder, + cu_seqlens_q=cu_seqlens_q, + attn_mask_offsets_full=attn_mask_offsets_full, + attn_mask_offsets_decoder=attn_mask_offsets_decoder.copy(), + is_block_step=np.array(is_block_step, dtype=bool).reshape(-1), + decode_states=decode_states.copy(), + mask_rollback=mask_rollback, + ) + + # optionally print debug if env var set + if os.environ.get("ATTN_MASK_TEST_DEBUG", "0") == "1": + print("=== DEBUG ===") + print("seq_lens_this_time:", seq_lens_this_time) + print("seq_lens_encoder:", seq_lens_encoder) + print("seq_lens_decoder:", seq_lens_decoder) + print("cu_seqlens_q:", cu_seqlens_q) + print("ref_attn_mask_offsets:", ref_attn_mask_offsets) + print("op_attn_mask_offsets:", op_attn_mask_offsets) + print("ref_decode_states:", ref_decode_states) + print("op_decode_states:", op_decode_states) + print("=============") + + # shape checks + self.assertEqual( + op_attn_mask_offsets.shape, + ref_attn_mask_offsets.shape, + f"attn_mask_offsets shape mismatch: op {op_attn_mask_offsets.shape}, ref {ref_attn_mask_offsets.shape}", + ) + # element-wise equality + np.testing.assert_array_equal(op_attn_mask_offsets, ref_attn_mask_offsets) + np.testing.assert_array_equal(op_decode_states, ref_decode_states) + + # --- Test cases below (cover branches) --- + + def test_stop_case(self): + # stop: both encoder and decoder are zero -> nothing written (all zeros) + self._call_and_compare( + seq_lens_this_time=[1], + seq_lens_encoder=[0], + seq_lens_decoder=[0], + is_block_step=[False], + max_model_len=4, + decode_states_len=2, + ) + + def test_prefill_case(self): + # prefill: encoder > 0, should copy attn_mask_offsets_full[i] + 1 into positions ((q+i)*2+1) + self._call_and_compare( 
+ seq_lens_this_time=[3], + seq_lens_encoder=[3], + seq_lens_decoder=[0], + is_block_step=[False], + max_model_len=8, + decode_states_len=4, + ) + + def test_vision_generate_prefill(self): + # vision generate: decode_states[0] == 2 and seq_len_decoder > 0 triggers alternate write + self._call_and_compare( + seq_lens_this_time=[2], + seq_lens_encoder=[2], + seq_lens_decoder=[5], # >0 to activate vision branch + is_block_step=[False], + max_model_len=8, + decode_states_len=4, + vision_generate=True, + ) + + def test_decoder_case(self): + # decoder path: should write attn_mask_offsets_decoder - rollback + 1 .. +seq_len_this_time-1 + self._call_and_compare( + seq_lens_this_time=[2], + seq_lens_encoder=[0], + seq_lens_decoder=[7], + is_block_step=[False], + max_model_len=8, + decode_states_len=6, + ) + + def test_mixed_batch_case(self): + # mixed batch with different statuses + self._call_and_compare( + seq_lens_this_time=[2, 4, 1], + seq_lens_encoder=[0, 4, 0], + seq_lens_decoder=[5, 0, 1], + is_block_step=[False, False, False], + max_model_len=12, + decode_states_len=2, + ) + + +if __name__ == "__main__": + unittest.main()