From 8a4ddb29df3663095835831a2c7c85639119f4f8 Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com> Date: Mon, 17 Nov 2025 16:14:55 +0800 Subject: [PATCH] Revert "[BugFix] Revert skip capture (#5023)" (#5080) --- .../get_block_shape_and_split_kv_block.cu | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 4a42235f5..3368eb620 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -15,6 +15,7 @@ #include "helper.h" #include "paddle/extension.h" #ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU +#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h" #include "paddle/phi/core/memory/memcpy.h" #endif #include "utils.cuh" @@ -287,9 +288,13 @@ void GetBlockShapeAndSplitKVBlock( seq_lens_encoder.data(), max_len_tensor_gpu.data(), bsz); - - max_len_tensor_cpu.copy_( - max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data + // is only for branching in attention. +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); auto max_len_cpu_ptr = max_len_tensor_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; @@ -398,9 +403,13 @@ void GetBlockShapeAndSplitKVBlock( bsz, decoder_block_shape_q, group_size); - - decoder_num_blocks_cpu.copy_( - decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU + // data is only for branching in attention. +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + decoder_num_blocks_cpu.copy_( + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); }