[Feature] dyc8 support prefixcache (#5125)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* dyc8 support prefixcache

* fix cache_trans test case

* update code
This commit is contained in:
kevin
2025-11-21 19:46:26 +08:00
committed by GitHub
parent ab3a2e45ff
commit c068a4f642
5 changed files with 272 additions and 121 deletions

View File

@@ -16,127 +16,159 @@
#include "paddle/extension.h"
// NOTE(review): this span looks like the removed (pre-change) side of the diff
// hunk introduced by the `@@ -16,127 +16,159 @@` header: the loop and function
// bodies opened here are not closed inside this span, and a reformatted
// definition with the same signature follows it. Confirm against the real file.
//
// Copies selected cache blocks between each layer's GPU tensor and the
// matching CPU buffer. `mode == 0` copies device -> host; any other value
// copies host -> device. Block id pairs are walked in order and runs where
// both the gpu id and the cpu id grow by exactly one are merged into a single
// cudaMemcpyAsync call.
//
// cache_gpu_tensors  : one GPU tensor per layer; shape is read as
//                      [max_block_num_gpu, num_heads, block_size, head_dim].
// cache_cpu_ptrs     : per-layer host addresses carried as int64_t and
//                      reinterpreted as data_t*.
// max_block_num_cpu  : exclusive upper bound used when asserting cpu block ids.
// swap_block_ids_gpu : block indices on the GPU side.
// swap_block_ids_cpu : block indices on the CPU side, paired by position with
//                      swap_block_ids_gpu.
// mode               : copy direction selector (see above).
template <paddle::DataType D>
void SwapCacheImplAllLayers(const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
const std::vector<int64_t>& cache_cpu_ptrs, // cpu
const int64_t& max_block_num_cpu,
const std::vector<int64_t>& swap_block_ids_gpu,
const std::vector<int64_t>& swap_block_ids_cpu,
int mode) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
// Stream of the first layer's tensor; shadowed inside the loop by the
// per-layer `stream` declared further down.
auto stream = cache_gpu_tensors[0].stream();
for(int layer_idx=0; layer_idx < cache_gpu_tensors.size(); layer_idx++){
const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
const int64_t& cache_cpu_pointer = cache_cpu_ptrs[layer_idx];
data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_pointer);
auto cache_shape = cache_gpu.shape();
const int64_t max_block_num_gpu = cache_shape[0];
const int64_t num_heads = cache_shape[1];
const int64_t block_size = cache_shape[2];
const int64_t head_dim = cache_shape[3];
// Number of elements occupied by one block.
const int64_t cache_stride = num_heads * block_size * head_dim;
auto stream = cache_gpu.stream();
// Nothing to swap: leaves the whole function, not just this layer.
if (swap_block_ids_gpu.size() == 0) {
return;
}
int i = 0;
int64_t consecutive_block_count = 1;
int64_t last_gpu_block_id = swap_block_ids_gpu[i];
int64_t last_cpu_block_id = swap_block_ids_cpu[i];
int64_t first_gpu_block_id = last_gpu_block_id; // first block id in a consecutive block ids
int64_t first_cpu_block_id = last_cpu_block_id;
i += 1;
while(true){
if (i >= swap_block_ids_gpu.size()) {
break;
}
int64_t gpu_block_id = swap_block_ids_gpu[i];
int64_t cpu_block_id = swap_block_ids_cpu[i];
// Only entries at index >= 1 reach these range checks; the first pair is
// read above without one.
assert(gpu_block_id >= 0 && gpu_block_id < max_block_num_gpu);
assert(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);
if (gpu_block_id == last_gpu_block_id + 1 && cpu_block_id == last_cpu_block_id + 1){ // consecutive
consecutive_block_count += 1;
last_gpu_block_id = gpu_block_id;
last_cpu_block_id = cpu_block_id;
} else{
// end of a consecutive block ids
auto *cache_gpu_ptr_now = cache_gpu_ptr + first_gpu_block_id * cache_stride;
auto *cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
if (mode == 0) { // copy from device to host
cudaMemcpyAsync(cache_cpu_ptr_now, cache_gpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyDeviceToHost, stream);
} else { // copy from host to device
cudaMemcpyAsync(cache_gpu_ptr_now, cache_cpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyHostToDevice, stream);
}
// Start a new run at the current pair.
first_gpu_block_id = gpu_block_id;
first_cpu_block_id = cpu_block_id;
last_gpu_block_id = gpu_block_id;
last_cpu_block_id = cpu_block_id;
consecutive_block_count = 1;
}
i += 1;
}
// last batch
auto *cache_gpu_ptr_now = cache_gpu_ptr + first_gpu_block_id * cache_stride;
auto *cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
if (mode == 0) { // copy from device to host
cudaMemcpyAsync(cache_cpu_ptr_now, cache_gpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyDeviceToHost, stream);
} else { // copy from host to device
cudaMemcpyAsync(cache_gpu_ptr_now, cache_cpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyHostToDevice, stream);
}
// Copies selected cache blocks between every layer's GPU tensor and its CPU
// buffer, then blocks until the copies have finished.
//
// The id pairs (swap_block_ids_gpu[i], swap_block_ids_cpu[i]) are walked in
// order. Runs in which both ids grow by exactly one are merged so that each
// run costs a single cudaMemcpyAsync.
//
// Params:
//   cache_gpu_tensors   one GPU tensor per layer. Shape is read as
//                       [max_block_num_gpu, num_heads, block_size, head_dim];
//                       when the rank is not 4, head_dim is taken as 1.
//   cache_cpu_ptrs      per-layer host addresses carried as int64_t and
//                       reinterpreted as data_t*.
//   max_block_num_cpu   exclusive upper bound asserted for cpu block ids.
//   swap_block_ids_gpu  block indices on the GPU side.
//   swap_block_ids_cpu  block indices on the CPU side, paired by position.
//   mode                0 copies device -> host, anything else host -> device.
//
// Returns immediately (no copy, no synchronization) when there are no layers
// or no block ids.
// NOTE(review): return codes of the CUDA runtime calls are not inspected here,
// same as before.
void SwapCacheImplAllLayers(
    const std::vector<paddle::Tensor>& cache_gpu_tensors,  // gpu
    const std::vector<int64_t>& cache_cpu_ptrs,            // cpu
    const int64_t& max_block_num_cpu,
    const std::vector<int64_t>& swap_block_ids_gpu,
    const std::vector<int64_t>& swap_block_ids_cpu,
    int mode) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const size_t num_swaps = swap_block_ids_gpu.size();
  if (num_swaps == 0 || cache_gpu_tensors.empty()) {
    return;
  }
  // Ids are consumed pairwise, so a shorter cpu list would be read out of
  // bounds below.
  assert(swap_block_ids_cpu.size() == num_swaps);
  assert(cache_cpu_ptrs.size() >= cache_gpu_tensors.size());

  // Stream that is synchronized once all layers have been enqueued.
  auto stream = cache_gpu_tensors[0].stream();

  for (size_t layer_idx = 0; layer_idx < cache_gpu_tensors.size();
       ++layer_idx) {
    const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
    data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
    auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_ptrs[layer_idx]);

    const auto cache_shape = cache_gpu.shape();
    assert(cache_shape.size() >= 3);
    const int64_t max_block_num_gpu = cache_shape[0];
    const int64_t num_heads = cache_shape[1];
    const int64_t block_size = cache_shape[2];
    const int64_t head_dim = (cache_shape.size() == 4) ? cache_shape[3] : 1;
    // Number of elements occupied by one block.
    const int64_t cache_stride = num_heads * block_size * head_dim;

    // Copies of this layer are enqueued on the layer's own stream. It gets a
    // distinct name so the stream synchronized at the end is not shadowed.
    auto layer_stream = cache_gpu.stream();

    // Enqueues one async copy covering `count` consecutive blocks.
    const auto copy_run = [&](int64_t first_gpu_block_id,
                              int64_t first_cpu_block_id,
                              int64_t count) {
      data_t* gpu_run = cache_gpu_ptr + first_gpu_block_id * cache_stride;
      data_t* cpu_run = cache_cpu_ptr + first_cpu_block_id * cache_stride;
      const size_t num_bytes = cache_stride * sizeof(DataType_) * count;
      if (mode == 0) {  // copy from device to host
        cudaMemcpyAsync(
            cpu_run, gpu_run, num_bytes, cudaMemcpyDeviceToHost, layer_stream);
      } else {  // copy from host to device
        cudaMemcpyAsync(
            gpu_run, cpu_run, num_bytes, cudaMemcpyHostToDevice, layer_stream);
      }
    };

    int64_t run_gpu_start = swap_block_ids_gpu[0];
    int64_t run_cpu_start = swap_block_ids_cpu[0];
    // The first pair needs the same range check as every later pair.
    assert(run_gpu_start >= 0 && run_gpu_start < max_block_num_gpu);
    assert(run_cpu_start >= 0 && run_cpu_start < max_block_num_cpu);
    int64_t run_length = 1;

    for (size_t i = 1; i < num_swaps; ++i) {
      const int64_t gpu_block_id = swap_block_ids_gpu[i];
      const int64_t cpu_block_id = swap_block_ids_cpu[i];
      assert(gpu_block_id >= 0 && gpu_block_id < max_block_num_gpu);
      assert(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);

      const bool extends_run = gpu_block_id == run_gpu_start + run_length &&
                               cpu_block_id == run_cpu_start + run_length;
      if (extends_run) {
        ++run_length;
        continue;
      }
      // The current run ended: flush it and start a new one at this pair.
      copy_run(run_gpu_start, run_cpu_start, run_length);
      run_gpu_start = gpu_block_id;
      run_cpu_start = cpu_block_id;
      run_length = 1;
    }
    // Flush the last run.
    copy_run(run_gpu_start, run_cpu_start, run_length);

    // A layer living on a different stream would not be covered by the final
    // synchronization below, so wait for it here.
    if (layer_stream != stream) {
      cudaStreamSynchronize(layer_stream);
    }
  }

  // Single wait after everything has been enqueued, instead of stalling
  // before each layer.
  cudaStreamSynchronize(stream);
}
// NOTE(review): this span looks like the removed (pre-change) side of the diff
// hunk; the function's closing brace is not inside this span and a reformatted
// definition with the same signature follows it. Confirm against the real file.
//
// Kernel entry point: selects the CUDA device given by `rank`, asserts that
// there is at least one GPU tensor and as many CPU pointers as GPU tensors,
// then dispatches on the dtype of the first GPU tensor to the matching
// SwapCacheImplAllLayers instantiation (BFLOAT16, FLOAT16 or UINT8). Any other
// dtype raises through PD_THROW.
void SwapCacheAllLayers(const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
const std::vector<int64_t>& cache_cpu_ptrs, // cpu memory pointer
int64_t max_block_num_cpu, // cpu max block num
const std::vector<int64_t>& swap_block_ids_gpu,
const std::vector<int64_t>& swap_block_ids_cpu,
int rank,
int mode) {
cudaSetDevice(rank); // used for distributed launch
// This check is an assert, so it is only active when asserts are compiled in.
assert(cache_gpu_tensors.size() > 0 && cache_gpu_tensors.size() == cache_cpu_ptrs.size());
// Only the first tensor's dtype is inspected.
switch (cache_gpu_tensors[0].dtype()) {
case paddle::DataType::BFLOAT16:
return SwapCacheImplAllLayers<paddle::DataType::BFLOAT16>(
cache_gpu_tensors,
cache_cpu_ptrs,
max_block_num_cpu,
swap_block_ids_gpu,
swap_block_ids_cpu,
mode);
case paddle::DataType::FLOAT16:
return SwapCacheImplAllLayers<paddle::DataType::FLOAT16>(
cache_gpu_tensors,
cache_cpu_ptrs,
max_block_num_cpu,
swap_block_ids_gpu,
swap_block_ids_cpu,
mode);
case paddle::DataType::UINT8:
return SwapCacheImplAllLayers<paddle::DataType::UINT8>(
cache_gpu_tensors,
cache_cpu_ptrs,
max_block_num_cpu,
swap_block_ids_gpu,
swap_block_ids_cpu,
mode);
default:
PD_THROW("Unsupported data type.");
}
// Kernel entry point for swapping cache blocks of all layers between GPU and
// CPU memory.
//
// Validates the per-layer inputs, selects the CUDA device given by `rank`,
// and dispatches on the dtype of the first GPU tensor to the matching
// SwapCacheImplAllLayers instantiation.
//
// Params:
//   cache_gpu_tensors   one GPU tensor per layer; must be non-empty.
//   cache_cpu_ptrs      one host address per layer; must have the same length
//                       as cache_gpu_tensors.
//   max_block_num_cpu   cpu max block num, forwarded to the implementation.
//   swap_block_ids_gpu  block indices on the GPU side.
//   swap_block_ids_cpu  block indices on the CPU side.
//   rank                CUDA device ordinal passed to cudaSetDevice.
//   mode                copy direction selector, forwarded unchanged.
//
// Throws (PD_THROW) when the layer lists are empty or differ in length, or
// when the dtype is not BFLOAT16, FLOAT16 or UINT8.
void SwapCacheAllLayers(
    const std::vector<paddle::Tensor>& cache_gpu_tensors,  // gpu
    const std::vector<int64_t>& cache_cpu_ptrs,  // cpu memory pointer
    int64_t max_block_num_cpu,                   // cpu max block num
    const std::vector<int64_t>& swap_block_ids_gpu,
    const std::vector<int64_t>& swap_block_ids_cpu,
    int rank,
    int mode) {
  // Checked at runtime rather than with assert(): an assert is compiled out
  // under NDEBUG, which would leave `cache_gpu_tensors[0]` below undefined for
  // an empty list and let the implementation index past cache_cpu_ptrs.
  if (cache_gpu_tensors.empty() ||
      cache_gpu_tensors.size() != cache_cpu_ptrs.size()) {
    PD_THROW(
        "cache_gpu_tensors must be non-empty and have the same size as "
        "cache_cpu_ptrs.");
  }

  cudaSetDevice(rank);  // used for distributed launch

  // Only the first tensor's dtype is inspected; the remaining layers are
  // handled with the same instantiation.
  switch (cache_gpu_tensors[0].dtype()) {
    case paddle::DataType::BFLOAT16:
      return SwapCacheImplAllLayers<paddle::DataType::BFLOAT16>(
          cache_gpu_tensors,
          cache_cpu_ptrs,
          max_block_num_cpu,
          swap_block_ids_gpu,
          swap_block_ids_cpu,
          mode);
    case paddle::DataType::FLOAT16:
      return SwapCacheImplAllLayers<paddle::DataType::FLOAT16>(
          cache_gpu_tensors,
          cache_cpu_ptrs,
          max_block_num_cpu,
          swap_block_ids_gpu,
          swap_block_ids_cpu,
          mode);
    case paddle::DataType::UINT8:
      return SwapCacheImplAllLayers<paddle::DataType::UINT8>(
          cache_gpu_tensors,
          cache_cpu_ptrs,
          max_block_num_cpu,
          swap_block_ids_gpu,
          swap_block_ids_cpu,
          mode);
    default:
      PD_THROW("Unsupported data type.");
  }
}
// Registers the `swap_cache_all_layers` static op with SwapCacheAllLayers as
// its kernel function.
// NOTE(review): `.Attrs(...)` and `.SetInplaceMap(...)` each appear twice below
// with the same contents in different formatting. This looks like both sides
// of a diff hunk rather than an intentional double call; confirm against the
// real file.
PD_BUILD_STATIC_OP(swap_cache_all_layers)
// Tensor input: a vector of tensors named "cache_gpu_tensors".
.Inputs({paddle::Vec("cache_gpu_tensors")})
// Attributes, listed in the same order and with the same types as the
// non-tensor parameters of SwapCacheAllLayers.
.Attrs({"cache_cpu_ptrs: std::vector<int64_t>",
"max_block_num_cpu: int64_t",
"swap_block_ids_gpu: std::vector<int64_t>",
"swap_block_ids_cpu: std::vector<int64_t>",
"rank: int",
"mode: int",})
.Attrs({
"cache_cpu_ptrs: std::vector<int64_t>",
"max_block_num_cpu: int64_t",
"swap_block_ids_gpu: std::vector<int64_t>",
"swap_block_ids_cpu: std::vector<int64_t>",
"rank: int",
"mode: int",
})
// Output vector "cache_dst_outs" is mapped in place onto the input vector
// "cache_gpu_tensors" by the inplace map below.
.Outputs({paddle::Vec("cache_dst_outs")})
.SetInplaceMap({{paddle::Vec("cache_gpu_tensors"), paddle::Vec("cache_dst_outs")}})
.SetInplaceMap({{paddle::Vec("cache_gpu_tensors"),
paddle::Vec("cache_dst_outs")}})
.SetKernelFn(PD_KERNEL(SwapCacheAllLayers));