diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
index c70712a7e..3368eb620 100644
--- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
+++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
@@ -15,9 +15,9 @@
 #include "helper.h"
 #include "paddle/extension.h"
 #ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "paddle/phi/core/memory/memcpy.h"
 #endif
-#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "utils.cuh"
 
 template
@@ -290,10 +290,11 @@ void GetBlockShapeAndSplitKVBlock(
       bsz);
   // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data
   // is only for branching in attention.
-  if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
+#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  if (!phi::backends::gpu::IsCUDAGraphCapturing())
+#endif
     max_len_tensor_cpu.copy_(
         max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
-  }
   auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
   int max_len_this_time = max_len_cpu_ptr[0];
@@ -404,10 +405,11 @@ void GetBlockShapeAndSplitKVBlock(
       group_size);
   // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU
   // data is only for branching in attention.
-  if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
+#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  if (!phi::backends::gpu::IsCUDAGraphCapturing())
+#endif
     decoder_num_blocks_cpu.copy_(
         decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
-  }
   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
       decoder_chunk_size_device.data<int32_t>(), 64, sizeof(int32_t), stream));
 }
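
For readers skimming the diff, a minimal standalone sketch of the pattern it applies is below. On METAX custom-device builds the CUDA-graph header and the IsCUDAGraphCapturing() check are compiled out, so the device-to-host copy always runs; on CUDA builds the copy is still skipped while a CUDA graph is being captured. The helper names (is_capturing, copy_to_cpu, maybe_sync_to_cpu) are illustrative placeholders, not the Paddle APIs used in the file.

// Sketch only: placeholders standing in for phi::backends::gpu::IsCUDAGraphCapturing()
// and the paddle::Tensor::copy_() DtoH copy in the real file.
#include <cstdio>

#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
static bool is_capturing() { return false; }  // placeholder for IsCUDAGraphCapturing()
#endif

static void copy_to_cpu() { std::puts("DtoH copy"); }  // placeholder for Tensor::copy_

static void maybe_sync_to_cpu() {
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  // CUDA build: skip the (slow) DtoH copy while a CUDA graph is being captured.
  if (!is_capturing())
#endif
    // METAX build: the guard above is compiled out, so the copy runs unconditionally.
    copy_to_cpu();
}

int main() { maybe_sync_to_cpu(); }

Dropping the braces around the copy_() call is what lets the #ifndef/#endif pair wrap only the if line: when the macro is defined the condition disappears and the copy statement that follows is compiled unconditionally, without duplicating it in both build paths.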