Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Feature] dyc8 support prefixcache (#5125)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* dyc8 support prefixcache
* fix cache_trans test case
* update code
@@ -16,127 +16,159 @@
 #include "paddle/extension.h"

 template <paddle::DataType D>
-void SwapCacheImplAllLayers(const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
-                            const std::vector<int64_t>& cache_cpu_ptrs, // cpu
-                            const int64_t& max_block_num_cpu,
-                            const std::vector<int64_t>& swap_block_ids_gpu,
-                            const std::vector<int64_t>& swap_block_ids_cpu,
-                            int mode) {
-    typedef PDTraits<D> traits_;
-    typedef typename traits_::DataType DataType_;
-    typedef typename traits_::data_t data_t;
-    auto stream = cache_gpu_tensors[0].stream();
-    for(int layer_idx=0; layer_idx < cache_gpu_tensors.size(); layer_idx++){
-        const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
-        const int64_t& cache_cpu_pointer = cache_cpu_ptrs[layer_idx];
-        data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
-        auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_pointer);
-        auto cache_shape = cache_gpu.shape();
-        const int64_t max_block_num_gpu = cache_shape[0];
-        const int64_t num_heads = cache_shape[1];
-        const int64_t block_size = cache_shape[2];
-        const int64_t head_dim = cache_shape[3];
-        const int64_t cache_stride = num_heads * block_size * head_dim;
-
-        auto stream = cache_gpu.stream();
-        if (swap_block_ids_gpu.size() == 0) {
-            return;
-        }
-        int i = 0;
-        int64_t consecutive_block_count = 1;
-        int64_t last_gpu_block_id = swap_block_ids_gpu[i];
-        int64_t last_cpu_block_id = swap_block_ids_cpu[i];
-        int64_t first_gpu_block_id = last_gpu_block_id; // first block id in a consecutive block ids
-        int64_t first_cpu_block_id = last_cpu_block_id;
-        i += 1;
-        while(true){
-            if (i >= swap_block_ids_gpu.size()) {
-                break;
-            }
-            int64_t gpu_block_id = swap_block_ids_gpu[i];
-            int64_t cpu_block_id = swap_block_ids_cpu[i];
-            assert(gpu_block_id >= 0 && gpu_block_id < max_block_num_gpu);
-            assert(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);
-            if (gpu_block_id == last_gpu_block_id + 1 && cpu_block_id == last_cpu_block_id + 1){ // consecutive
-                consecutive_block_count += 1;
-                last_gpu_block_id = gpu_block_id;
-                last_cpu_block_id = cpu_block_id;
-            } else{
-                // end of a consecutive block ids
-                auto *cache_gpu_ptr_now = cache_gpu_ptr + first_gpu_block_id * cache_stride;
-                auto *cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
-                if (mode == 0) { // copy from device to host
-                    cudaMemcpyAsync(cache_cpu_ptr_now, cache_gpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyDeviceToHost, stream);
-                } else { // copy from host to device
-                    cudaMemcpyAsync(cache_gpu_ptr_now, cache_cpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyHostToDevice, stream);
-                }
-                first_gpu_block_id = gpu_block_id;
-                first_cpu_block_id = cpu_block_id;
-                last_gpu_block_id = gpu_block_id;
-                last_cpu_block_id = cpu_block_id;
-                consecutive_block_count = 1;
-            }
-            i += 1;
-        }
-        // last batch
-        auto *cache_gpu_ptr_now = cache_gpu_ptr + first_gpu_block_id * cache_stride;
-        auto *cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
-        if (mode == 0) { // copy from device to host
-            cudaMemcpyAsync(cache_cpu_ptr_now, cache_gpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyDeviceToHost, stream);
-        } else { // copy from host to device
-            cudaMemcpyAsync(cache_gpu_ptr_now, cache_cpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyHostToDevice, stream);
-        }
+void SwapCacheImplAllLayers(
+    const std::vector<paddle::Tensor>& cache_gpu_tensors,  // gpu
+    const std::vector<int64_t>& cache_cpu_ptrs,  // cpu
+    const int64_t& max_block_num_cpu,
+    const std::vector<int64_t>& swap_block_ids_gpu,
+    const std::vector<int64_t>& swap_block_ids_cpu,
+    int mode) {
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
+  auto stream = cache_gpu_tensors[0].stream();
+  for (int layer_idx = 0; layer_idx < cache_gpu_tensors.size(); layer_idx++) {
+    const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
+    const int64_t& cache_cpu_pointer = cache_cpu_ptrs[layer_idx];
+    data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
+    auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_pointer);
+    auto cache_shape = cache_gpu.shape();
+    const int64_t max_block_num_gpu = cache_shape[0];
+    const int64_t num_heads = cache_shape[1];
+    const int64_t block_size = cache_shape[2];
+    int64_t head_dim = 1;
+    if (cache_shape.size() == 4) {
+      head_dim = cache_shape[3];
+    }
+    cudaStreamSynchronize(stream);
+    const int64_t cache_stride = num_heads * block_size * head_dim;
+
+    auto stream = cache_gpu.stream();
+    if (swap_block_ids_gpu.size() == 0) {
+      return;
+    }
+    int i = 0;
+    int64_t consecutive_block_count = 1;
+    int64_t last_gpu_block_id = swap_block_ids_gpu[i];
+    int64_t last_cpu_block_id = swap_block_ids_cpu[i];
+    int64_t first_gpu_block_id =
+        last_gpu_block_id;  // first block id in a consecutive block ids
+    int64_t first_cpu_block_id = last_cpu_block_id;
+    i += 1;
+    while (true) {
+      if (i >= swap_block_ids_gpu.size()) {
+        break;
+      }
+      int64_t gpu_block_id = swap_block_ids_gpu[i];
+      int64_t cpu_block_id = swap_block_ids_cpu[i];
+      assert(gpu_block_id >= 0 && gpu_block_id < max_block_num_gpu);
+      assert(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);
+      if (gpu_block_id == last_gpu_block_id + 1 &&
+          cpu_block_id == last_cpu_block_id + 1) {  // consecutive
+        consecutive_block_count += 1;
+        last_gpu_block_id = gpu_block_id;
+        last_cpu_block_id = cpu_block_id;
+      } else {
+        // end of a consecutive block ids
+        auto* cache_gpu_ptr_now =
+            cache_gpu_ptr + first_gpu_block_id * cache_stride;
+        auto* cache_cpu_ptr_now =
+            cache_cpu_ptr + first_cpu_block_id * cache_stride;
+        if (mode == 0) {  // copy from device to host
+          cudaMemcpyAsync(
+              cache_cpu_ptr_now,
+              cache_gpu_ptr_now,
+              cache_stride * sizeof(DataType_) * consecutive_block_count,
+              cudaMemcpyDeviceToHost,
+              stream);
+        } else {  // copy from host to device
+          cudaMemcpyAsync(
+              cache_gpu_ptr_now,
+              cache_cpu_ptr_now,
+              cache_stride * sizeof(DataType_) * consecutive_block_count,
+              cudaMemcpyHostToDevice,
+              stream);
+        }
+        first_gpu_block_id = gpu_block_id;
+        first_cpu_block_id = cpu_block_id;
+        last_gpu_block_id = gpu_block_id;
+        last_cpu_block_id = cpu_block_id;
+        consecutive_block_count = 1;
+      }
+      i += 1;
+    }
+    // last batch
+    auto* cache_gpu_ptr_now = cache_gpu_ptr + first_gpu_block_id * cache_stride;
+    auto* cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
+    if (mode == 0) {  // copy from device to host
+      cudaMemcpyAsync(
+          cache_cpu_ptr_now,
+          cache_gpu_ptr_now,
+          cache_stride * sizeof(DataType_) * consecutive_block_count,
+          cudaMemcpyDeviceToHost,
+          stream);
+    } else {  // copy from host to device
+      cudaMemcpyAsync(
+          cache_gpu_ptr_now,
+          cache_cpu_ptr_now,
+          cache_stride * sizeof(DataType_) * consecutive_block_count,
+          cudaMemcpyHostToDevice,
+          stream);
+    }
+  }
+  cudaStreamSynchronize(stream);
+}
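The loop above coalesces runs of consecutive block ids so that each run becomes a single cudaMemcpyAsync over count * cache_stride elements. As a rough illustration only (not part of this commit; the helper name coalesce_runs is invented), the grouping step corresponds to this Python sketch:

    def coalesce_runs(gpu_ids, cpu_ids):
        """Group (gpu, cpu) block-id pairs into consecutive runs.

        Each returned run (gpu_start, cpu_start, count) can be moved with one
        memcpy, mirroring the consecutive_block_count logic in the C++ loop.
        """
        runs = []
        if not gpu_ids:
            return runs
        gpu_start, cpu_start, count = gpu_ids[0], cpu_ids[0], 1
        for g, c in zip(gpu_ids[1:], cpu_ids[1:]):
            if g == gpu_start + count and c == cpu_start + count:
                count += 1  # still consecutive on both the GPU and CPU side
            else:
                runs.append((gpu_start, cpu_start, count))
                gpu_start, cpu_start, count = g, c, 1
        runs.append((gpu_start, cpu_start, count))
        return runs

    # e.g. coalesce_runs([3, 4, 5, 9], [0, 1, 2, 7]) -> [(3, 0, 3), (9, 7, 1)]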

-void SwapCacheAllLayers(const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
-                        const std::vector<int64_t>& cache_cpu_ptrs, // cpu memory pointer
-                        int64_t max_block_num_cpu, // cpu max block num
-                        const std::vector<int64_t>& swap_block_ids_gpu,
-                        const std::vector<int64_t>& swap_block_ids_cpu,
-                        int rank,
-                        int mode) {
-    cudaSetDevice(rank); // used for distributed launch
-    assert(cache_gpu_tensors.size() > 0 && cache_gpu_tensors.size() == cache_cpu_ptrs.size());
-    switch (cache_gpu_tensors[0].dtype()) {
-        case paddle::DataType::BFLOAT16:
-            return SwapCacheImplAllLayers<paddle::DataType::BFLOAT16>(
-                cache_gpu_tensors,
-                cache_cpu_ptrs,
-                max_block_num_cpu,
-                swap_block_ids_gpu,
-                swap_block_ids_cpu,
-                mode);
-        case paddle::DataType::FLOAT16:
-            return SwapCacheImplAllLayers<paddle::DataType::FLOAT16>(
-                cache_gpu_tensors,
-                cache_cpu_ptrs,
-                max_block_num_cpu,
-                swap_block_ids_gpu,
-                swap_block_ids_cpu,
-                mode);
-        case paddle::DataType::UINT8:
-            return SwapCacheImplAllLayers<paddle::DataType::UINT8>(
-                cache_gpu_tensors,
-                cache_cpu_ptrs,
-                max_block_num_cpu,
-                swap_block_ids_gpu,
-                swap_block_ids_cpu,
-                mode);
-        default:
-            PD_THROW("Unsupported data type.");
-    }
+void SwapCacheAllLayers(
+    const std::vector<paddle::Tensor>& cache_gpu_tensors,  // gpu
+    const std::vector<int64_t>& cache_cpu_ptrs,  // cpu memory pointer
+    int64_t max_block_num_cpu,  // cpu max block num
+    const std::vector<int64_t>& swap_block_ids_gpu,
+    const std::vector<int64_t>& swap_block_ids_cpu,
+    int rank,
+    int mode) {
+  cudaSetDevice(rank);  // used for distributed launch
+  assert(cache_gpu_tensors.size() > 0 &&
+         cache_gpu_tensors.size() == cache_cpu_ptrs.size());
+  switch (cache_gpu_tensors[0].dtype()) {
+    case paddle::DataType::BFLOAT16:
+      return SwapCacheImplAllLayers<paddle::DataType::BFLOAT16>(
+          cache_gpu_tensors,
+          cache_cpu_ptrs,
+          max_block_num_cpu,
+          swap_block_ids_gpu,
+          swap_block_ids_cpu,
+          mode);
+    case paddle::DataType::FLOAT16:
+      return SwapCacheImplAllLayers<paddle::DataType::FLOAT16>(
+          cache_gpu_tensors,
+          cache_cpu_ptrs,
+          max_block_num_cpu,
+          swap_block_ids_gpu,
+          swap_block_ids_cpu,
+          mode);
+    case paddle::DataType::UINT8:
+      return SwapCacheImplAllLayers<paddle::DataType::UINT8>(cache_gpu_tensors,
+                                                             cache_cpu_ptrs,
+                                                             max_block_num_cpu,
+                                                             swap_block_ids_gpu,
+                                                             swap_block_ids_cpu,
+                                                             mode);
+    default:
+      PD_THROW("Unsupported data type.");
+  }
+}

 PD_BUILD_STATIC_OP(swap_cache_all_layers)
     .Inputs({paddle::Vec("cache_gpu_tensors")})
-    .Attrs({"cache_cpu_ptrs: std::vector<int64_t>",
-            "max_block_num_cpu: int64_t",
-            "swap_block_ids_gpu: std::vector<int64_t>",
-            "swap_block_ids_cpu: std::vector<int64_t>",
-            "rank: int",
-            "mode: int",})
+    .Attrs({
+        "cache_cpu_ptrs: std::vector<int64_t>",
+        "max_block_num_cpu: int64_t",
+        "swap_block_ids_gpu: std::vector<int64_t>",
+        "swap_block_ids_cpu: std::vector<int64_t>",
+        "rank: int",
+        "mode: int",
+    })
     .Outputs({paddle::Vec("cache_dst_outs")})
-    .SetInplaceMap({{paddle::Vec("cache_gpu_tensors"), paddle::Vec("cache_dst_outs")}})
+    .SetInplaceMap({{paddle::Vec("cache_gpu_tensors"),
+                     paddle::Vec("cache_dst_outs")}})
     .SetKernelFn(PD_KERNEL(SwapCacheAllLayers));
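The op registered above is what the Python cache transfer manager (shown further down in this diff) invokes. A hedged sketch of a single device-to-host swap, with placeholder variable values; the import path is assumed rather than taken from this diff:

    # Illustrative only: the call shape follows the Python changes below.
    from fastdeploy.model_executor.ops.gpu import swap_cache_all_layers  # import path assumed

    swap_cache_all_layers(
        gpu_cache_k_tensors,   # list[paddle.Tensor], one KV-cache tensor per layer (GPU)
        k_dst_ptrs,            # list[int], pinned-host base pointers from cuda_host_alloc
        num_cpu_blocks,        # max_block_num_cpu
        gpu_block_ids,         # block ids to read on the GPU side
        cpu_block_ids,         # matching destination block ids on the CPU side
        rank,                  # device id; the kernel calls cudaSetDevice(rank)
        0,                     # mode: 0 = GPU -> CPU (SWAP2CPU), 1 = CPU -> GPU (SWAP2GPU)
    )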
@@ -63,7 +63,7 @@ def parse_args():
         "--cache_dtype",
         type=str,
         default="bfloat16",
-        choices=["uint8", "bfloat16"],
+        choices=["uint8", "bfloat16", "block_wise_fp8"],
         help="cache dtype",
     )
     parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
@@ -114,6 +114,8 @@ class CacheTransferManager:
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []
         self.gpu_cache_v_tensors = []
+        self.gpu_cache_scales_k_tensors = []
+        self.gpu_cache_scales_v_tensors = []
         self.speculative_config = SpeculativeConfig(args.speculative_config)
         self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
         self.value_cache_shape = []
@@ -131,6 +133,7 @@ class CacheTransferManager:
         self.rank = rank
         self.device = device
         self.engine_pid = args.engine_pid
+        self.cache_dtype = args.cache_dtype

         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
@@ -203,12 +206,19 @@ class CacheTransferManager:
            time.sleep(0.1)
        logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")

+        if args.cache_dtype == "block_wise_fp8":
+            cache_type = "uint8"
+        else:
+            cache_type = args.cache_dtype
+
        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
        set_device(self.device)
        for i in range(args.num_layers + self.num_extra_layers):
            num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
            key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
            val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
+            key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}"
+            value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}.device{self.device}"
            key_cache_shape = [
                num_gpu_blocks,
                self.key_cache_shape[1],
@@ -227,26 +237,64 @@ class CacheTransferManager:
                logger.info(
                    f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
                )
-                key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=args.cache_dtype)
+                key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
                set_data_ipc(key_cache, key_name)
+
+                if args.cache_dtype == "block_wise_fp8":
+                    key_cache_scales = paddle.full(
+                        shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    set_data_ipc(key_cache_scales, key_cache_scales_name)
                if self.value_cache_shape:
-                    val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=args.cache_dtype)
+                    val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
                    set_data_ipc(val_cache, val_name)
+
+                    if args.cache_dtype == "block_wise_fp8":
+                        value_cache_scales = paddle.full(
+                            shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
+                            fill_value=0,
+                            dtype=paddle.get_default_dtype(),
+                        )
+                        set_data_ipc(value_cache_scales, value_cache_scales_name)
            else:
                logger.info(
                    f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
                )
-                key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
-                val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
+                key_cache = paddle.empty(shape=[], dtype=cache_type)
+                val_cache = paddle.empty(shape=[], dtype=cache_type)
                key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
+                if args.cache_dtype == "block_wise_fp8":
+                    key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                    key_cache_scales = share_external_data_(
+                        key_cache_scales,
+                        key_cache_scales_name,
+                        [num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
+                        True,
+                    )
                if self.value_cache_shape:
                    val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
+                    if args.cache_dtype == "block_wise_fp8":
+                        value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                        value_cache_scales = share_external_data_(
+                            value_cache_scales,
+                            value_cache_scales_name,
+                            [num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
+                            True,
+                        )

            self.gpu_cache_kvs[key_name] = key_cache
            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
+            if args.cache_dtype == "block_wise_fp8":
+                self.gpu_cache_kvs[key_cache_scales_name] = key_cache_scales
+                self.gpu_cache_scales_k_tensors.append(self.gpu_cache_kvs[key_cache_scales_name])
            if args.value_cache_shape:
                self.gpu_cache_kvs[val_name] = val_cache
                self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
+                if args.cache_dtype == "block_wise_fp8":
+                    self.gpu_cache_kvs[value_cache_scales_name] = value_cache_scales
+                    self.gpu_cache_scales_v_tensors.append(self.gpu_cache_kvs[value_cache_scales_name])

        if args.create_cache_tensor:
            logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
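To make the block_wise_fp8 layout above concrete, here is a rough sketch of the per-layer tensors it allocates; the sizes are invented for illustration and are not taken from this diff:

    import paddle

    # Hypothetical per-layer cache geometry, for illustration only.
    num_gpu_blocks, num_heads, block_size, head_dim = 1024, 8, 64, 128

    # The FP8 payload is stored as raw uint8 bytes ...
    key_cache = paddle.full(
        shape=[num_gpu_blocks, num_heads, block_size, head_dim], fill_value=0, dtype="uint8"
    )
    # ... with one scale per [shape[1], shape[2]] slot kept in the default dtype (e.g. bfloat16).
    key_cache_scales = paddle.full(
        shape=[num_gpu_blocks, num_heads, block_size], fill_value=0, dtype=paddle.get_default_dtype()
    )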
@@ -265,12 +313,17 @@ class CacheTransferManager:
            value_cache_size = 0
        if args.cache_dtype == "bfloat16":
            cache_bytes = 2
-        elif args.cache_dtype == "uint8":
+        elif args.cache_dtype == "uint8" or args.cache_dtype == "block_wise_fp8":
            cache_bytes = 1
        else:
            raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
        key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
        value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
+        if args.cache_dtype == "block_wise_fp8":
+            cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+            cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
+            scales_key_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
+            scales_value_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
        logger.info(
            f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
        )
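As a rough worked example of the sizing above (numbers invented, not from this diff):

    # Hypothetical configuration, for illustration only.
    num_cpu_blocks = 2048
    num_heads, block_size, head_dim = 8, 64, 128        # key_cache_shape[1:4]
    key_cache_size = num_heads * block_size * head_dim   # elements per CPU block

    cache_bytes = 1                                       # uint8 payload for block_wise_fp8
    key_bytes = num_cpu_blocks * cache_bytes * key_cache_size    # 2048*8*64*128 bytes = 128 MiB
    scales_elems = num_heads * block_size                 # one scale per (head, slot)
    scales_bytes = num_cpu_blocks * 2 * scales_elems      # bfloat16 scales -> 2 MiB
    print(key_bytes / 1024**2, scales_bytes / 1024**2)    # 128.0 2.0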
@@ -282,17 +335,27 @@ class CacheTransferManager:
        paddle.set_device("cpu")
        self.k_dst_ptrs = []
        self.v_dst_ptrs = []
+        self.k_scales_ptrs = []
+        self.v_scales_ptrs = []
        for i in range(args.num_layers + self.num_extra_layers):
            key_name = f"key_caches_{i}_rank{self.rank}"
            val_name = f"value_caches_{i}_rank{self.rank}"
+            key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
+            value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}"
            logger.info(
                f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
            )
            self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
            self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
+            if args.cache_dtype == "block_wise_fp8":
+                self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(scales_key_need_to_allocate_bytes)
+                self.k_scales_ptrs.append(self.cpu_cache_kvs[key_cache_scales_name])
            if value_need_to_allocate_bytes > 0:
                self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
                self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
+                if args.cache_dtype == "block_wise_fp8":
+                    self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(scales_value_need_to_allocate_bytes)
+                    self.v_scales_ptrs.append(self.cpu_cache_kvs[value_cache_scales_name])
        logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
        self.swap_space_ready_signal.value[self.rank] = 1

@@ -492,6 +555,25 @@ class CacheTransferManager:
                    self.device,
                    0,
                )
+                if self.cache_dtype == "block_wise_fp8":
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_k_tensors,
+                        self.k_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        0,
+                    )
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_v_tensors,
+                        self.v_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        0,
+                    )

            elif event_type.value == CacheStatus.SWAP2GPU.value:
                swap_cache_all_layers(
@@ -512,6 +594,25 @@ class CacheTransferManager:
                    self.device,
                    1,
                )
+                if self.cache_dtype == "block_wise_fp8":
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_k_tensors,
+                        self.k_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        1,
+                    )
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_v_tensors,
+                        self.v_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        1,
+                    )
            else:
                logger.warning(
                    f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"

@@ -1239,6 +1239,8 @@ class CacheConfig:
            self.enable_hierarchical_cache = True

        if self.model_cfg is not None:
            if self.model_cfg.quantization is not None and isinstance(self.model_cfg.quantization, dict):
                self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype)
+            if self.model_cfg.quantization_config is not None:
+                self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype)
            if (
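In practice this means a model whose quantization_config carries a kv_cache_quant_type entry switches the KV cache dtype automatically; a hedged illustration with invented config values:

    # Illustrative only: a quantization_config like this would make CacheConfig
    # pick up "block_wise_fp8" as cache_dtype via the .get() call above.
    quantization_config = {
        "quantization": "wint4",                    # hypothetical weight-quantization entry
        "kv_cache_quant_type": "block_wise_fp8",
    }
    cache_dtype = quantization_config.get("kv_cache_quant_type", "bfloat16")
    print(cache_dtype)  # block_wise_fp8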
@@ -1450,8 +1450,10 @@ class GPUModelRunner(ModelRunnerBase):
        for i in range(self.model_config.num_hidden_layers):
            # init key cache
            key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+            key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device}"
            if value_cache_shape:
                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+                value_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device}"
            if create_cache_tensor:
                logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
                key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
@@ -1477,12 +1479,25 @@ class GPUModelRunner(ModelRunnerBase):
                logger.info(f"..attaching kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
                key_cache = paddle.empty(shape=[], dtype=cache_type)
                key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape)
+                if kv_cache_quant_type == "block_wise_fp8":
+                    key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                    key_cache_scales = share_external_data(
+                        key_cache_scales, key_cache_scales_name, kv_cache_scale_shape
+                    )
                if value_cache_shape:
                    val_cache = paddle.empty(shape=[], dtype=cache_type)
                    val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape)
                    cache_kvs_list.extend([key_cache, val_cache])
+                    if kv_cache_quant_type == "block_wise_fp8":
+                        val_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                        val_cache_scales = share_external_data(
+                            val_cache_scales, value_cache_scales_name, kv_cache_scale_shape
+                        )
+                        cache_kvs_list.extend([key_cache_scales, val_cache_scales])
                else:
                    cache_kvs_list.extend([key_cache])
+                    if kv_cache_quant_type == "block_wise_fp8":
+                        cache_kvs_list.extend([key_cache_scales])

        self.share_inputs["caches"] = cache_kvs_list

@@ -25,6 +25,7 @@ class Args:
    key_cache_shape = "1,1,1,1"
    value_cache_shape = ""
    create_cache_tensor = False
+    cache_dtype = "bfloat16"


# ==========================