diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh index a01cc1e53..cc537e46c 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh @@ -52,6 +52,7 @@ __global__ void multi_query_append_attention_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, + const int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -74,6 +75,11 @@ __global__ void multi_query_append_attention_kernel( block_table_now = block_table + batch_id * max_block_num_per_seq; + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + const uint32_t q_len = seq_lens[batch_id]; if (q_len <= 0) { return; @@ -422,6 +428,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, + const int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -445,6 +452,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const uint32_t num_rows_per_block = num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + const uint32_t q_len = seq_lens[batch_id]; if (q_len <= 0) { return; @@ -902,6 +914,7 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -960,6 +973,7 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), @@ -1134,6 +1148,7 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -1206,6 +1221,7 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh index 16ef7e9e4..49317bfdf 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh @@ -57,6 +57,7 @@ __global__ void multi_query_append_attention_c4_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, + const int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -85,6 +86,11 @@ __global__ void multi_query_append_attention_c4_kernel( block_table_now = block_table + batch_id * max_block_num_per_seq; + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + const uint32_t q_len = seq_lens[batch_id]; if (q_len <= 0) { return; @@ -520,6 +526,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, + const 
int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -549,6 +556,11 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( const uint32_t num_rows_per_block = num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + const uint32_t q_len = seq_lens[batch_id]; if (q_len <= 0) { return; @@ -1107,6 +1119,7 @@ void MultiQueryAppendC4Attention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -1171,6 +1184,7 @@ void MultiQueryAppendC4Attention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), @@ -1365,6 +1379,7 @@ void MultiQueryAppendC4Attention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -1445,6 +1460,7 @@ void MultiQueryAppendC4Attention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index 77ba87814..b2fe4c6f6 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -58,6 +58,7 @@ __global__ void multi_query_append_attention_c8_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, + const int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -87,6 +88,11 @@ __global__ void multi_query_append_attention_c8_kernel( block_table_now = block_table + batch_id * max_block_num_per_seq; + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + const uint32_t q_len = seq_lens[batch_id]; if (q_len <= 0) { return; @@ -527,6 +533,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, + const int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -556,6 +563,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const uint32_t num_rows_per_block = num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + const uint32_t q_len = seq_lens[batch_id]; if (q_len <= 0) { return; @@ -1159,6 +1171,7 @@ void MultiQueryAppendC8Attention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -1217,6 +1230,7 @@ void MultiQueryAppendC8Attention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), @@ -1443,6 +1457,7 @@ void MultiQueryAppendC8Attention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -1517,6 +1532,7 @@ void 
MultiQueryAppendC8Attention( quant_min_bound, in_scale, chunk_size, + num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index b9c951d39..dbe072250 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -191,14 +191,21 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time, } } -std::vector GetBlockShapeAndSplitKVBlock( +void GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, paddle::Tensor &decoder_batch_ids, // Inplace paddle::Tensor &decoder_tile_ids_per_batch, // Inplace paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory - paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory + paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU + paddle::Tensor &encoder_batch_ids, // Inplace + paddle::Tensor &encoder_tile_ids_per_batch, // Inplace + paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU + paddle::Tensor &kv_batch_ids, // Inplace + paddle::Tensor &kv_tile_ids_per_batch, // Inplace + paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU + paddle::Tensor &max_len_kv_cpu, // Inplace, CPU const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, @@ -223,13 +230,7 @@ std::vector GetBlockShapeAndSplitKVBlock( int max_system_len = max_len_cpu_ptr[6]; int max_just_dec_len_without_system = max_len_cpu_ptr[7]; - paddle::Tensor encoder_batch_ids; - paddle::Tensor encoder_tile_ids_per_batch; - paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor kv_batch_ids; - paddle::Tensor kv_tile_ids_per_batch; - paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor max_len_kv_cpu; /*cpu*/ + auto max_len_kv = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place()); @@ -237,17 +238,14 @@ std::vector GetBlockShapeAndSplitKVBlock( max_len_kv.data(), seq_lens_this_time.data(), seq_lens_decoder.data(), bsz); - max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false); + + max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false); if (max_enc_len_this_time > 0) { - const uint32_t max_tile_size_per_bs_kv = - div_up(max_enc_dec_len_this_time, block_size); - kv_batch_ids = - GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32, - seq_lens_encoder.place()); - kv_tile_ids_per_batch = - GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32, - seq_lens_encoder.place()); + const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); + const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv; + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); auto kv_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); @@ -258,16 +256,12 @@ std::vector GetBlockShapeAndSplitKVBlock( kv_tile_ids_per_batch.data(), kv_num_blocks_x.data(), bsz, block_size, block_size); - kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false); - - const uint32_t encoder_max_tile_size_per_bs_q = - div_up((max_enc_dec_len_this_time 
* group_size), encoder_block_shape_q); - encoder_batch_ids = - GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_tile_ids_per_batch = - GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); + kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false); + // Clear buffer + const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); + const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q; + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data(), 0, encoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data(), 0, encoder_batch_shape * sizeof(int32_t), stream)); auto encoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), nullptr, @@ -275,21 +269,7 @@ std::vector GetBlockShapeAndSplitKVBlock( encoder_tile_ids_per_batch.data(), encoder_num_blocks_x.data(), bsz, encoder_block_shape_q, group_size); - encoder_num_blocks_x_cpu = - encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); - } else { - encoder_batch_ids = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_tile_ids_per_batch = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_num_blocks_x_cpu = - GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); - kv_batch_ids = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - kv_tile_ids_per_batch = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - kv_num_blocks_x_cpu = - GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); + encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); } if (max_just_dec_len_this_time > 0) { @@ -314,15 +294,6 @@ std::vector GetBlockShapeAndSplitKVBlock( decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false); } - return { - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks_x_cpu, /*cpu*/ - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks_x_cpu, /*cpu*/ - max_len_kv_cpu, /*cpu*/ - }; } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) @@ -333,16 +304,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) "decoder_batch_ids", "decoder_tile_ids_per_batch", "decoder_num_blocks_x_cpu", - "max_len_tensor_cpu" + "max_len_tensor_cpu", + "encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks_x_cpu", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks_x_cpu", + "max_len_kv_cpu" }) .Outputs({ - paddle::Optional("encoder_batch_ids"), - paddle::Optional("encoder_tile_ids_per_batch"), - paddle::Optional("encoder_num_blocks_x_cpu"), - paddle::Optional("kv_batch_ids"), - paddle::Optional("kv_tile_ids_per_batch"), - paddle::Optional("kv_num_blocks_x_cpu"), - "max_len_kv_cpu" + }) .Attrs({ "encoder_block_shape_q: int", diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index b0bb23604..88280b079 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -299,7 +299,7 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id, paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata, const 
int layer_id); -std::vector GetBlockShapeAndSplitKVBlock( +void GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, @@ -307,6 +307,13 @@ std::vector GetBlockShapeAndSplitKVBlock( paddle::Tensor &decoder_tile_ids_per_batch, // Inplace paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory + paddle::Tensor &encoder_batch_ids, // Inplace + paddle::Tensor &encoder_tile_ids_per_batch, // Inplace + paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor &kv_batch_ids, // Inplace + paddle::Tensor &kv_tile_ids_per_batch, // Inplace + paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 73f318db7..105242058 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -580,6 +580,10 @@ class GraphOptimizationConfig: """ Whether to use a full cuda graph for the entire forward pass rather than splitting certain operations such as attention into subgraphs. Thus this flag cannot be used together with splitting_ops.""" + self.cudagraph_only_prefill: bool = False + """When cudagraph_only_prefill is False, only capture decode-only. + When cudagraph_only_prefill is True, only capture prefill-only. + Now don't support capture both decode-only and prefill-only""" self.full_cuda_graph: bool = True self.max_capture_size: int = None @@ -592,13 +596,13 @@ class GraphOptimizationConfig: self.check_legality_parameters() - def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None: + def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None: """ Initialize cuda graph capture sizes and pre-compute the mapping from batch size to padded graph size """ # Regular capture sizes - self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs] + self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size] dedup_sizes = list(set(self.cudagraph_capture_sizes)) if len(dedup_sizes) < len(self.cudagraph_capture_sizes): logger.info( @@ -632,7 +636,7 @@ class GraphOptimizationConfig: # Shape [128, 144, ... 240, 256] draft_capture_sizes += [16 * i for i in range(9, 17)] # Shape [256, 288, ... 
992, 1024] - draft_capture_sizes += [32 * i for i in range(17, 33)] + draft_capture_sizes += [32 * i for i in range(9, 33)] draft_capture_sizes.append(max_num_seqs) self.cudagraph_capture_sizes = sorted(draft_capture_sizes) @@ -1140,7 +1144,11 @@ class FDConfig: # Initialize cuda graph capture list if self.graph_opt_config.cudagraph_capture_sizes is None: self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs) - self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs) + + if self.graph_opt_config.cudagraph_only_prefill: + self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512) + else: + self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.parallel_config.max_num_seqs) # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn if self.graph_opt_config.graph_opt_level == 2: diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 968495733..c775beaaf 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -81,14 +81,42 @@ class ForwardMeta: attn_mask: Optional[paddle.Tensor] = None # Attention mask offset attn_mask_offsets: Optional[paddle.Tensor] = None + + # A common pattern for launching CUDA kernels is to set the kernel's grids.x dimension + # using a `num_blocks` variable, and then map each thread block to a specific batch and + # data tile using `batch_ids` and `tile_ids_per_batch`. + # + # The variable names below follow this pattern, using a common prefix (e.g., `encoder_`, `decoder_`, `kv_`) + # for variables that are logically grouped together. The mapping works as follows: + # + # Usage: `my_kernel<<>>(..., batch_ids, tile_ids, ...)` + # `grids.x` = `num_blocks_cpu` + # `batch_id` = `batch_ids[blockIdx.x]` + # `tile_id` = `tile_ids[blockIdx.x]` + + # Maps the thread block index (blockIdx.x) to the corresponding batch for the decoder stage in multi_query_append_attention_warp1_4_kernel. # Decoder batch id. Used by attention backend. decoder_batch_ids: Optional[paddle.Tensor] = None - # Tile ID for each batch of the decoder. Used by attention backend. + # Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel. decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None - # The number of blocks that attention backend can use in decode stage + # The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_warp1_4_kernel, defining its grids.x. decoder_num_blocks_cpu: Optional[paddle.Tensor] = None - # Recorded multiple lengths related to prefill or decode + # A tensor that holds multiple lengths related to prefill or decode stages. max_len_tensor_cpu: Optional[paddle.Tensor] = None + # Maps the thread block index (blockIdx.x) to the corresponding batch for the encoder stage in multi_query_append_attention_kernel. + encoder_batch_ids: Optional[paddle.Tensor] = None + # Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the encoder stage in multi_query_append_attention_kernel. + encoder_tile_ids_per_batch: Optional[paddle.Tensor] = None + # The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_kernel, defining its grids.x. 
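(Aside, not part of the patch: the launch pattern described in the comments above can be emulated on the host with a few lines of Python. The helper name plan_q_blocks and the numbers below are hypothetical; in the real code this bookkeeping is produced on the GPU by split_q_block, and a thread block whose blockIdx.x is at or beyond the reported count returns early, which is what makes an over-sized CUDA-graph launch safe.)

import math

def plan_q_blocks(seq_lens, block_shape_q, group_size):
    # Hypothetical host-side emulation of split_q_block's bookkeeping.
    batch_ids, tile_ids = [], []
    for batch_id, seq_len in enumerate(seq_lens):
        for tile_id in range(math.ceil(seq_len * group_size / block_shape_q)):
            batch_ids.append(batch_id)  # batch handled by this thread block
            tile_ids.append(tile_id)    # tile within that batch
    num_blocks = len(batch_ids)         # becomes grids.x of the attention kernel
    return batch_ids, tile_ids, num_blocks

# e.g. seq_lens=[160, 80], group_size=5, block_shape_q=64 -> num_blocks = 13 + 7 = 20
# batch_ids = [0]*13 + [1]*7, tile_ids = list(range(13)) + list(range(7))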
+ encoder_num_blocks_x_cpu: Optional[paddle.Tensor] = None + # Maps the thread block index (blockIdx.x) to the corresponding batch for the append_write_cache_kv kernel. + kv_batch_ids: Optional[paddle.Tensor] = None + # Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the append_write_cache_kv kernel. + kv_tile_ids_per_batch: Optional[paddle.Tensor] = None + # The number of CUDA blocks to launch in the x-dimension for the append_write_cache_kv kernel, defining its grids.x. + kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None + # The maximum sequence length of the KV cache, which may represent the current maximum decoder length. + max_len_kv_cpu: Optional[paddle.Tensor] = None # Sequence length of encoder for ever batch seq_lens_encoder: Optional[paddle.Tensor] = None @@ -133,6 +161,7 @@ class ForwardMeta: "shape": obj.shape, "dtype": str(obj.dtype), "place": str(obj.place), + # "content": obj if obj.numel()<10 else "Too big to show" } return tensor_info elif isinstance(obj, (list, tuple)): diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index aa47aa391..64023e7e2 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -49,14 +49,6 @@ class AppendAttentionMetadata(AttentionMetadata): AppendAttentionMetadata """ - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - max_len_kv: paddle.Tensor = None - _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 @@ -142,15 +134,7 @@ class AppendAttentionBackend(AttentionBackend): metadata.rotary_embs = forward_meta.rotary_embs metadata.attn_mask = forward_meta.attn_mask metadata.pre_caches_length = forward_meta.pre_caches_length - ( - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, - metadata.max_len_kv, - ) = get_block_shape_and_split_kv_block( + get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -158,6 +142,13 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, @@ -288,17 +279,17 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, metadata.block_tables, - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + 
forward_meta.kv_num_blocks_x_cpu, forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, - metadata.max_len_kv, + forward_meta.max_len_kv_cpu, res, metadata.rotary_embs, metadata.attn_mask, @@ -344,17 +335,17 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, metadata.block_tables, - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, - metadata.max_len_kv, + forward_meta.max_len_kv_cpu, metadata.rotary_embs, metadata.attn_mask, layer.qkv_bias, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 8f220ddb9..c4c504368 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata): rotary_embs: Optional[paddle.Tensor] = None block_tables: Optional[paddle.Tensor] = None - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - max_len_kv: paddle.Tensor = None cu_seqlens_q: paddle.Tensor = None cu_seqlens_k: paddle.Tensor = None @@ -198,15 +191,7 @@ class FlashAttentionBackend(AttentionBackend): metadata.cu_seqlens_q = forward_meta.cu_seqlens_q metadata.rotary_embs = forward_meta.rotary_embs metadata.block_tables = forward_meta.block_tables - ( - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, - metadata.max_len_kv, - ) = get_block_shape_and_split_kv_block( + get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -214,6 +199,13 @@ class FlashAttentionBackend(AttentionBackend): forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, @@ -295,9 +287,9 @@ class FlashAttentionBackend(AttentionBackend): forward_meta.seq_lens_decoder, forward_meta.batch_id_per_token, metadata.block_tables, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, metadata.pre_cache_batch_ids, metadata.pre_cache_tile_ids_per_batch, metadata.pre_cache_num_blocks_cpu, @@ -336,17 +328,17 @@ class FlashAttentionBackend(AttentionBackend): 
forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, metadata.block_tables, - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, forward_meta.decoder_batch_ids, # from buffer forward_meta.decoder_tile_ids_per_batch, # from buffer forward_meta.decoder_num_blocks_cpu, metadata.max_len_tensor_cpu_decoder, - metadata.max_len_kv, + forward_meta.max_len_kv_cpu, metadata.rotary_embs, forward_meta.attn_mask, layer.qkv_bias, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 2cf961f21..724f6eae5 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -69,14 +69,6 @@ class MLAAttentionMetadata(AttentionMetadata): MLAAttentionMetadata for Multi-Layer Attention """ - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - max_len_kv: paddle.Tensor = None - _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 @@ -191,15 +183,7 @@ class MLAAttentionBackend(AttentionBackend): metadata.attn_mask = forward_meta.attn_mask metadata.pre_caches_length = forward_meta.pre_caches_length - ( - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, - metadata.max_len_kv, - ) = get_block_shape_and_split_kv_block( + get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -207,6 +191,13 @@ class MLAAttentionBackend(AttentionBackend): forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, @@ -362,19 +353,19 @@ class MLAAttentionBackend(AttentionBackend): forward_meta.cu_seqlens_q, forward_meta.batch_id_per_token, metadata.block_tables, - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.decoder_num_blocks_cpu, metadata.max_enc_len_this_time, metadata.max_dec_len_this_time, - metadata.max_len_kv, + forward_meta.max_len_kv_cpu, None, # attn_mask None, # qkv_bias 
None, # qkv_out_scales @@ -483,19 +474,19 @@ class MLAAttentionBackend(AttentionBackend): forward_meta.cu_seqlens_q, forward_meta.batch_id_per_token, metadata.block_tables, - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.decoder_num_blocks_cpu, metadata.max_enc_len_this_time, metadata.max_dec_len_this_time, - metadata.max_len_kv, + forward_meta.max_len_kv_cpu, None, # attn_mask None, # qkv_bias None, # qkv_out_scales diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index dd57b5259..68a7402b8 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -32,6 +32,13 @@ def get_block_shape_and_split_kv_block( decoder_tile_ids_per_batch: paddle.Tensor, decoder_num_blocks_x_cpu: paddle.Tensor, max_len_tensor_cpu: paddle.Tensor, + encoder_batch_ids: paddle.Tensor, + encoder_tile_ids_per_batch: paddle.Tensor, + encoder_num_blocks_x_cpu: paddle.Tensor, + kv_batch_ids: paddle.Tensor, + kv_tile_ids_per_batch: paddle.Tensor, + kv_num_blocks_x_cpu: paddle.Tensor, + max_len_kv_cpu: paddle.Tensor, encoder_block_shape_q: int, decoder_block_shape_q: int, group_size: int, @@ -42,15 +49,7 @@ def get_block_shape_and_split_kv_block( get_block_shape_and_split_kv_block """ if current_platform.is_cuda(): - ( - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, - max_len_kv_cpu, - ) = get_block_shape_and_split_kv_block_cuda( + get_block_shape_and_split_kv_block_cuda( seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, @@ -58,20 +57,19 @@ def get_block_shape_and_split_kv_block( decoder_tile_ids_per_batch, decoder_num_blocks_x_cpu, max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_x_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + max_len_kv_cpu, encoder_block_shape_q, decoder_block_shape_q, group_size, block_size, decoder_step_token_num, ) - return ( - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, - max_len_kv_cpu, - ) + else: raise NotImplementedError diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 1397c79bf..45c727419 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -212,6 +212,22 @@ class MTPProposer(Proposer): self.target_model_inputs["max_len_tensor_cpu"] ).cpu() + self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"]) + self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like( + self.target_model_inputs["encoder_tile_ids_per_batch"] + ) + self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like( + self.target_model_inputs["encoder_num_blocks_x_cpu"] + ).cpu() + self.model_inputs["kv_batch_ids"] = 
paddle.zeros_like(self.target_model_inputs["kv_batch_ids"]) + self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like( + self.target_model_inputs["kv_tile_ids_per_batch"] + ) + self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like( + self.target_model_inputs["kv_num_blocks_x_cpu"] + ).cpu() + self.model_inputs["max_len_kv_cpu"] = paddle.zeros_like(self.target_model_inputs["max_len_kv_cpu"]).cpu() + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -321,6 +337,13 @@ class MTPProposer(Proposer): self.model_inputs["decoder_tile_ids_per_batch"] = None self.model_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory self.model_inputs["max_len_tensor_cpu"] = None # CPU + self.model_inputs["encoder_batch_ids"] = None + self.model_inputs["encoder_tile_ids_per_batch"] = None + self.model_inputs["encoder_num_blocks_x_cpu"] = None # CPU + self.model_inputs["kv_batch_ids"] = None + self.model_inputs["kv_tile_ids_per_batch"] = None + self.model_inputs["kv_num_blocks_x_cpu"] = None # CPU + self.model_inputs["max_len_kv_cpu"] = None # CPU # Input tokens self.model_inputs["draft_tokens"] = paddle.full( @@ -512,6 +535,13 @@ class MTPProposer(Proposer): cu_seqlens_k=self.model_inputs["cu_seqlens_k"], block_tables=self.model_inputs["block_tables"], caches=self.model_inputs["caches"], + encoder_batch_ids=self.model_inputs["encoder_batch_ids"], + encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"], + encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"], + kv_batch_ids=self.model_inputs["kv_batch_ids"], + kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"], + kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"], + max_len_kv_cpu=self.model_inputs["max_len_kv_cpu"], ) # Initialzie attention meta data diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 07341c23b..ee56d032e 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -430,6 +430,13 @@ class GCUModelRunner(ModelRunnerBase): self.share_inputs["decoder_tile_ids_per_batch"] = None self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory self.share_inputs["max_len_tensor_cpu"] = None # CPU + self.share_inputs["encoder_batch_ids"] = None + self.share_inputs["encoder_tile_ids_per_batch"] = None + self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU + self.share_inputs["kv_batch_ids"] = None + self.share_inputs["kv_tile_ids_per_batch"] = None + self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU + self.share_inputs["max_len_kv_cpu"] = None # CPU # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) @@ -601,6 +608,13 @@ class GCUModelRunner(ModelRunnerBase): cu_seqlens_k=self.share_inputs["cu_seqlens_k"], block_tables=self.share_inputs["block_tables"], caches=self.share_inputs["caches"], + encoder_batch_ids=self.share_inputs["encoder_batch_ids"], + encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"], + encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"], + kv_batch_ids=self.share_inputs["kv_batch_ids"], + kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], + kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], + max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"], ) # Update Batch type for cuda graph @@ -673,14 +687,31 @@ class GCUModelRunner(ModelRunnerBase): encoder_block_shape_q = 64 
decoder_block_shape_q = 16 decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 + group_size = np.ceil(num_heads / self.model_config.kv_num_heads) + decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( - (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + (decoder_step_token_num * group_size) / decoder_block_shape_q + ) + encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + (self.model_config.max_model_len * group_size) / encoder_block_shape_q + ) + kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + self.model_config.max_model_len / self.fd_config.cache_config.block_size ) self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu() self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 9579f93b2..53591ca59 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -142,6 +142,7 @@ class GPUModelRunner(ModelRunnerBase): self.use_cudagraph = self.graph_opt_config.use_cudagraph self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes + self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill # Initialize share inputs self._init_share_inputs(self.parallel_config.max_num_seqs) @@ -177,10 +178,49 @@ class GPUModelRunner(ModelRunnerBase): """ check whether prefill stage exist """ - if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0: - return 1 - else: - return 0 + return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0 + + def exist_decode(self): + """ + check whether decode stage exist + """ + return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0 + + def only_prefill(self): + """ + check whether prefill only + """ + if_only_prefill = True + decode_exists = None + if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": + only_prefill_batch_list = [] + decode_exists = self.exist_decode() + paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists) + if_only_prefill = all(only_prefill_batch_list) + + if_only_prefill = if_only_prefill and not (decode_exists if decode_exists is not None else self.exist_decode()) + + return if_only_prefill + + def only_decode(self): + """ + check whether decode only + """ + # Update Batch type 
for cuda graph for if_only_decode
+        if_only_decode = True
+        prefill_exists = None
+        # mix ep in single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            only_decode_batch_list = []
+            prefill_exists = self.exist_prefill()
+            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
+            if_only_decode = all(only_decode_batch_list)
+
+        if_only_decode = if_only_decode and not (
+            prefill_exists if prefill_exists is not None else self.exist_prefill()
+        )
+
+        return if_only_decode
     def _init_speculative_proposer(self):
         """
@@ -600,27 +640,81 @@
         if self.speculative_method in ["mtp"]:
             self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)
-    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
-        """Set dummy prefill inputs to share_inputs"""
+    def get_input_length_list(
+        self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False
+    ):
+        """
+        Generates the input-length list for _dummy_prefill_inputs. When capturing pure prefill or MTP,
+        the list must be carefully constructed.
+
+        This function addresses a specific problem: in the pure prefill stage, variable
+        input lengths (e.g., `prompt[160, 0]` vs. `prompt[80, 80]`) can lead to different
+        CUDA Grid dimensions for kernels like `split_q_block`. This prevents CUDA Graph
+        reuse.
+
+        The `split_q_block` kernel calculates the total number of blocks, which directly
+        determines the `gridDim.x` launch parameter for the `multi_query_append_attention_kernel`.
+        The blocks for a single sequence are determined by the formula:
+        `num_blocks = ceil((sequence_length * group_size) / block_shape_q)`
+
+        Due to the `ceil` (ceiling) function, distributing a total number of tokens across
+        a batch of shorter sequences will result in a larger total block count. For example,
+        with a `group_size` of 5 and `block_shape_q` of 64:
+        - A single sequence of 160 tokens requires `ceil((160 * 5) / 64) = 13` blocks.
+        - Two sequences of 80 tokens each require `ceil((80 * 5) / 64) * 2 = 7 * 2 = 14` blocks.
+
+        To ensure graph replayability, this function creates a "dummy" list of sequence
+        lengths that is designed to produce the theoretical maximum `encoder_num_blocks_x_cpu`
+        for the given `num_tokens` and `batch_size`. This strategy ensures the captured
+        CUDA Graph has the largest possible grid dimensions. At runtime, if the actual number
+        of blocks is less than or equal to this maximum, the kernel can safely execute by
+        using an early-exit mechanism.
+
+        Args:
+            num_tokens (int): The total number of tokens across all sequences.
+            batch_size (int): The number of sequences (requests) in the batch.
+
+        Returns:
+            List[int]: A list of integers representing the sequence length for each request.
+                       This list is crafted to maximize the total number of blocks.
+        """
         # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
         max_dec_len = expected_decode_len + 1
-        full_length = min(
-            num_tokens // batch_size,
+        input_length = min(
+            num_tokens // (1 if capture_prefill else batch_size),
             self.parallel_config.max_model_len - max_dec_len,
         )
         # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough, causing the result to appear as nan.
         # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
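(A small self-check of the arithmetic in the docstring above, illustrative only; encoder_blocks is a hypothetical helper and the numbers mirror the docstring's example rather than any real configuration.)

import math

def encoder_blocks(lengths, group_size=5, block_shape_q=64):
    # total grids.x that split_q_block would report for these sequence lengths
    return sum(math.ceil(length * group_size / block_shape_q) for length in lengths)

assert encoder_blocks([160]) == 13        # one long prompt
assert encoder_blocks([80, 80]) == 14     # same token budget split in two -> more blocks

# The dummy list used for capture: batch_size - 1 ones plus all remaining tokens.
num_tokens, batch_size = 160, 2
dummy = [1] * (batch_size - 1) + [num_tokens - batch_size + 1]   # [1, 159]
assert encoder_blocks(dummy) == 14        # matches the largest block count for this example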
if self.fd_config.parallel_config.enable_expert_parallel: - full_length = min(full_length, 32) + input_length = min(input_length, 32) - input_length = int(full_length * self.cache_config.kv_cache_ratio) block_num = ( input_length + self.cache_config.block_size - 1 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num + input_length_list = [input_length] * batch_size + + if capture_prefill: + if num_tokens < batch_size: + input_length_list = [1] * num_tokens + else: + input_length_list = [1] * (batch_size - 1) + input_length_list.append(num_tokens - batch_size + 1) + + len_of_input_length_list = len(input_length_list) + max_dec_len_list = [max_dec_len] * len_of_input_length_list + + return input_length_list, max_dec_len_list, block_num + + def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int): + """Set dummy prefill inputs to share_inputs""" + batch_size = len(input_length_list) for i in range(batch_size): idx = i + input_length = input_length_list[i] + max_dec_len = max_dec_len_list[i] self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["eos_token_id"][:] = np.array( @@ -745,6 +839,13 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["decoder_tile_ids_per_batch"] = None self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory self.share_inputs["max_len_tensor_cpu"] = None # CPU + self.share_inputs["encoder_batch_ids"] = None + self.share_inputs["encoder_tile_ids_per_batch"] = None + self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU + self.share_inputs["kv_batch_ids"] = None + self.share_inputs["kv_tile_ids_per_batch"] = None + self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU + self.share_inputs["max_len_kv_cpu"] = None # CPU # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) @@ -977,23 +1078,30 @@ class GPUModelRunner(ModelRunnerBase): cu_seqlens_k=self.share_inputs["cu_seqlens_k"], block_tables=self.share_inputs["block_tables"], caches=self.share_inputs["caches"], + encoder_batch_ids=self.share_inputs["encoder_batch_ids"], + encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"], + encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"], + kv_batch_ids=self.share_inputs["kv_batch_ids"], + kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], + kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], + max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"], ) - # Update Batch type for cuda graph - only_decode_batch = True - prefill_exists = None - # mix ep in single node - if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": - only_decode_batch_list = [] - prefill_exists = self.exist_prefill() - paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) - only_decode_batch = all(only_decode_batch_list) - self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + # Update Batch type for cuda graph for only_decode_batch + if_only_decode = self.only_decode() + only_decode_use_cudagraph = self.use_cudagraph and if_only_decode + # Update config about moe for better performance + # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. 
It will be used in MoEMethodBase.apply() + if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": + self.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill" + + # Update Batch type for cuda graph for only_prefill_batch + only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill() + + # When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph] self.forward_meta.step_use_cudagraph = ( - self.use_cudagraph - and only_decode_batch - and not (prefill_exists if prefill_exists is not None else self.exist_prefill()) + only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph ) # Initialzie attention meta data @@ -1085,14 +1193,31 @@ class GPUModelRunner(ModelRunnerBase): encoder_block_shape_q = 64 decoder_block_shape_q = 16 decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 + group_size = np.ceil(num_heads / self.model_config.kv_num_heads) + decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( - (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + (decoder_step_token_num * group_size) / decoder_block_shape_q + ) + encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + (self.model_config.max_model_len * group_size) / encoder_block_shape_q + ) + kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + self.model_config.max_model_len / self.fd_config.cache_config.block_size ) self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -1112,6 +1237,7 @@ class GPUModelRunner(ModelRunnerBase): batch_size: paddle.Tensor, expected_decode_len: int = 1, in_capturing: bool = False, + capture_prefill: bool = False, ) -> paddle.Tensor: """ Use dummy inputs to run before formal execution. 
@@ -1119,11 +1245,19 @@ class GPUModelRunner(ModelRunnerBase): num_tokens: expected_decode_len: Expected number of tokens generated in_capturing: Is cuda graph in capturing state + capture_prefill: Capture pure prefill for cuda graph """ - self._dummy_prefill_inputs( + + input_length_list, max_dec_len_list, block_num = self.get_input_length_list( num_tokens=num_tokens, batch_size=batch_size, expected_decode_len=expected_decode_len, + capture_prefill=capture_prefill, + ) + self._dummy_prefill_inputs( + input_length_list=input_length_list, + max_dec_len_list=max_dec_len_list, + block_num=block_num, ) if self.speculative_method in ["mtp"]: self.proposer.dummy_prefill_inputs( @@ -1353,14 +1487,30 @@ class GPUModelRunner(ModelRunnerBase): time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() - for batch_size in sorted(capture_sizes, reverse=True): - self._dummy_run( - num_tokens=self.parallel_config.max_num_batched_tokens, - batch_size=batch_size, - in_capturing=True, - expected_decode_len=expected_decode_len, - ) - logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") + + if self.fd_config.graph_opt_config.cudagraph_only_prefill: + for num_tokens in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=num_tokens, + batch_size=self.parallel_config.max_num_seqs, + in_capturing=True, + expected_decode_len=expected_decode_len, + capture_prefill=True, + ) + logger.info( + f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" + ) + else: + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=batch_size, + in_capturing=True, + expected_decode_len=expected_decode_len, + ) + logger.info( + f"Warm up the model with the num_tokens:{batch_size}, expected_decode_len:{expected_decode_len}" + ) time_after_capture = time.perf_counter() logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index b15c3fb16..68a6b6fac 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -397,7 +397,6 @@ class PaddleDisWorkerProc: self.get_profile_block_num_signal.value[0] = num_blocks_local else: num_blocks_local = self.fd_config.parallel_config.total_block_num - logger.info(f"------- num_blocks_global: {num_blocks_local} --------") # wait engine launch cache_manager if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed": diff --git a/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py b/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py index 77c154b77..b6e74753b 100644 --- a/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py +++ b/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py @@ -157,7 +157,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase): cache_config = CacheConfig({}) # Initialize cuda graph capture list graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs) - graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs) + graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs) fd_config = FDConfig( graph_opt_config=graph_opt_config, parallel_config=parallel_config, diff --git a/tests/graph_optimization/test_cuda_graph_spec_decode.py 
b/tests/graph_optimization/test_cuda_graph_spec_decode.py index 3ea67d4b3..f4a95cead 100644 --- a/tests/graph_optimization/test_cuda_graph_spec_decode.py +++ b/tests/graph_optimization/test_cuda_graph_spec_decode.py @@ -104,7 +104,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase): cache_config = CacheConfig({}) # Initialize cuda graph capture list graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs) - graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs) + graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs) fd_config = FDConfig( graph_opt_config=graph_opt_config, parallel_config=parallel_config, diff --git a/tests/graph_optimization/test_static_graph_cuda_graph_split.py b/tests/graph_optimization/test_static_graph_cuda_graph_split.py index 5bb721eba..faaad4127 100644 --- a/tests/graph_optimization/test_static_graph_cuda_graph_split.py +++ b/tests/graph_optimization/test_static_graph_cuda_graph_split.py @@ -90,7 +90,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase): graph_opt_config = GraphOptimizationConfig({"use_cudagraph": True, "graph_opt_level": 1}) parallel_config = ParallelConfig({"max_num_seqs": 1}) graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs) - graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs) + graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs) cache_config = CacheConfig({}) fd_config = FDConfig( diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index 8616437ab..fcae19509 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -386,6 +386,14 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory() self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") + self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") + self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") + self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") + self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu() + self.cache_shape = ( self.max_block_num, self.kv_num_head, @@ -469,15 +477,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): get_block_shape_and_split_kv_block, ) - ( - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, - max_len_kv, - ) = get_block_shape_and_split_kv_block( + get_block_shape_and_split_kv_block( self.seq_lens_encoder, self.seq_lens_decoder, self.seq_lens_this_time, @@ -485,6 +485,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, self.max_len_tensor_cpu, + self.encoder_batch_ids, + self.encoder_tile_ids_per_batch, + self.encoder_num_blocks_x_cpu, + self.kv_batch_ids, + self.kv_tile_ids_per_batch, + self.kv_num_blocks_x_cpu, + self.max_len_kv_cpu, 64, 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, @@ -508,17 +515,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.padding_offset, self.cum_offset, 
diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py
index 8616437ab..fcae19509 100644
--- a/tests/layers/test_append_attention.py
+++ b/tests/layers/test_append_attention.py
@@ -386,6 +386,14 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
         self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()

+        self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+        self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+        self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+        self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+        self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+        self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+        self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()
+
         self.cache_shape = (
             self.max_block_num,
             self.kv_num_head,
@@ -469,15 +477,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             get_block_shape_and_split_kv_block,
         )

-        (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             self.seq_lens_encoder,
             self.seq_lens_decoder,
             self.seq_lens_this_time,
@@ -485,6 +485,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.decoder_tile_ids_per_batch,
             self.decoder_num_blocks_cpu,
             self.max_len_tensor_cpu,
+            self.encoder_batch_ids,
+            self.encoder_tile_ids_per_batch,
+            self.encoder_num_blocks_x_cpu,
+            self.kv_batch_ids,
+            self.kv_tile_ids_per_batch,
+            self.kv_num_blocks_x_cpu,
+            self.max_len_kv_cpu,
             64,
             12,
             (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
@@ -508,17 +515,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.padding_offset,
             self.cum_offset,
             self.block_tables,
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
+            self.encoder_batch_ids,
+            self.encoder_tile_ids_per_batch,
+            self.encoder_num_blocks_x_cpu,
+            self.kv_batch_ids,
+            self.kv_tile_ids_per_batch,
+            self.kv_num_blocks_x_cpu,
             self.decoder_batch_ids,
             self.decoder_tile_ids_per_batch,
             self.decoder_num_blocks_cpu,
             self.max_len_tensor_cpu,
-            max_len_kv,
+            self.max_len_kv_cpu,
             self.rope_emb,  # rope_emb
             None,  # attn_mask
             None,  # qkv_bias
diff --git a/tests/layers/test_append_attention_with_output.py b/tests/layers/test_append_attention_with_output.py
index 7b4121cf4..47cc1f384 100644
--- a/tests/layers/test_append_attention_with_output.py
+++ b/tests/layers/test_append_attention_with_output.py
@@ -382,6 +382,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
         self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
         self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
+        self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+        self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+        self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+        self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+        self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+        self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+        self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()

         self.cache_shape = (
             self.max_block_num,
@@ -450,15 +457,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             get_block_shape_and_split_kv_block,
         )

-        (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             self.seq_lens_encoder,
             self.seq_lens_decoder,
             self.seq_lens_this_time,
@@ -466,6 +465,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.decoder_tile_ids_per_batch,
             self.decoder_num_blocks_cpu,
             self.max_len_tensor_cpu,
+            self.encoder_batch_ids,
+            self.encoder_tile_ids_per_batch,
+            self.encoder_num_blocks_x_cpu,
+            self.kv_batch_ids,
+            self.kv_tile_ids_per_batch,
+            self.kv_num_blocks_x_cpu,
+            self.max_len_kv_cpu,
             64,
             12,
             (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
@@ -491,17 +497,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.padding_offset,
             self.cum_offset,
             self.block_tables,
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
+            self.encoder_batch_ids,
+            self.encoder_tile_ids_per_batch,
+            self.encoder_num_blocks_x_cpu,
+            self.kv_batch_ids,
+            self.kv_tile_ids_per_batch,
+            self.kv_num_blocks_x_cpu,
             self.decoder_batch_ids,
             self.decoder_tile_ids_per_batch,
             self.decoder_num_blocks_cpu,
             self.max_len_tensor_cpu,
-            max_len_kv,
+            self.max_len_kv_cpu,
             out,
             self.rope_emb,  # rope_emb
             None,  # attn_mask
diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py
index 59c4b1d98..1d8c81b12 100644
--- a/tests/operators/test_tree_mask.py
+++ b/tests/operators/test_tree_mask.py
@@ -190,30 +190,32 @@ class TestTreeMask(unittest.TestCase):
         encoder_block_shape_q = 64
         decoder_block_shape_q = 16
-
+        group_size = self.num_q_head // self.num_kv_head
         decode_max_tile_size = (
-            self.bsz
-            * (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1)
-            / decoder_block_shape_q
+            self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
         )
+        encode_max_tile_size = (
+            self.bsz * (self.max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q
+        )
+        kv_max_tile_size = self.bsz * (self.max_seq_len + self.block_size - 1) / self.block_size
+
         decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
         decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
         decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
         max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
+        encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+        encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+        encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+        kv_batch_ids = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+        kv_tile_ids_per_batch = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+        kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+        max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()

         q_norm_weight = np.ones([self.head_dim])
         k_norm_weight = np.ones([self.head_dim])
         self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
         self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
         paddle.device.synchronize()
-        (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             seq_lens_encoder,
             seq_lens_decoder,
             seq_lens_this_time,
@@ -221,6 +223,13 @@ class TestTreeMask(unittest.TestCase):
             decoder_tile_ids_per_batch,
             decoder_num_blocks,
             max_len_tensor_cpu,
+            encoder_batch_ids,
+            encoder_tile_ids_per_batch,
+            encoder_num_blocks_x_cpu,
+            kv_batch_ids,
+            kv_tile_ids_per_batch,
+            kv_num_blocks_x_cpu,
+            max_len_kv_cpu,
             encoder_block_shape_q,
             decoder_block_shape_q,
             self.num_q_head // self.num_kv_head,
@@ -243,15 +252,15 @@ class TestTreeMask(unittest.TestCase):
             self.block_tables,
             encoder_batch_ids,
             encoder_tile_ids_per_batch,
-            encoder_num_blocks,
+            encoder_num_blocks_x_cpu,
             kv_batch_ids,
             kv_tile_ids_per_batch,
-            kv_num_blocks,
+            kv_num_blocks_x_cpu,
             decoder_batch_ids,
             decoder_tile_ids_per_batch,
             decoder_num_blocks,
             max_len_tensor_cpu,
-            max_len_kv,
+            max_len_kv_cpu,
             rotary_embs,
             attn_mask,
             None,  # qkv_bias
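The buffer bounds introduced in `test_tree_mask.py` are ceil-division upper limits on how many tiles the split kernel can emit: each sequence contributes at most `ceil(max_seq_len * group_size / encoder_block_shape_q)` encoder tiles and `ceil(max_seq_len / block_size)` KV tiles. A quick worked check with hypothetical sizes (the batch size, head counts, and lengths below are made up, not the test's actual values):

```python
# Hypothetical sizes, just to sanity-check the tile-count upper bounds.
bsz = 2
num_q_head, num_kv_head = 16, 2
group_size = num_q_head // num_kv_head  # 8 query heads share each KV head
max_seq_len, block_size = 512, 64
encoder_block_shape_q, decoder_block_shape_q = 64, 16
decoder_step_token_num = 4

decode_max_tile_size = bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
encode_max_tile_size = bsz * (max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q
kv_max_tile_size = bsz * (max_seq_len + block_size - 1) / block_size

# int(...) truncates, matching how the test sizes its paddle.full buffers.
print(int(decode_max_tile_size))  # 2 * (4 * 8 + 15) / 16 = 5.875    -> 5
print(int(encode_max_tile_size))  # 2 * (4096 + 63) / 64  ~= 129.97  -> 129
print(int(kv_max_tile_size))      # 2 * (512 + 63) / 64   ~= 17.97   -> 17
```

The summed per-sequence ceilings are integers no larger than these products, so truncating with `int(...)` still leaves room for every tile the kernel can schedule.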