Revert "[BugFix] Revert skip capture (#5023)" (#5080)

This commit is contained in:
Sunny-bot1
2025-11-17 16:14:55 +08:00
committed by GitHub
parent 7f94d77e08
commit 8a4ddb29df

View File

@@ -15,6 +15,7 @@
#include "helper.h"
#include "paddle/extension.h"
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "paddle/phi/core/memory/memcpy.h"
#endif
#include "utils.cuh"
@@ -287,9 +288,13 @@ void GetBlockShapeAndSplitKVBlock(
seq_lens_encoder.data<int>(),
max_len_tensor_gpu.data<int>(),
bsz);
max_len_tensor_cpu.copy_(
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
// Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data
// is only for branching in attention.
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if (!phi::backends::gpu::IsCUDAGraphCapturing())
#endif
max_len_tensor_cpu.copy_(
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
@@ -398,9 +403,13 @@ void GetBlockShapeAndSplitKVBlock(
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
// Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU
// data is only for branching in attention.
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if (!phi::backends::gpu::IsCUDAGraphCapturing())
#endif
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}