[Cherry-Pick][Feature] dy-c8 prefix caching (#4918)

* c8 prefix caching

* update code

* update code

* update cache trans

* update code

* update code
kevin authored on 2025-11-28 10:37:49 +08:00, committed by GitHub
parent f637ba708c
commit b52e1bd281
4 changed files with 100 additions and 5 deletions


@@ -35,7 +35,10 @@ void SwapCacheImplAllLayers(const std::vector<paddle::Tensor>& cache_gpu_tensors
   const int64_t max_block_num_gpu = cache_shape[0];
   const int64_t num_heads = cache_shape[1];
   const int64_t block_size = cache_shape[2];
-  const int64_t head_dim = cache_shape[3];
+  int64_t head_dim = 1;
+  if (cache_shape.size() == 4) {
+    head_dim = cache_shape[3];
+  }
   const int64_t cache_stride = num_heads * block_size * head_dim;
   auto stream = cache_gpu.stream();
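The kernel change above lets SwapCacheImplAllLayers handle 3-D cache tensors as well as the usual 4-D KV data caches: the block-wise FP8 scale caches added later in this patch are shaped [num_blocks, kv_num_head, block_size] with no head_dim axis, so head_dim falls back to 1 and cache_stride is still the element count of one block. A minimal sketch of that stride arithmetic in Python (function name and sample shapes are illustrative, not from the kernel):

```python
def elements_per_block(cache_shape):
    """Mirror of the stride logic: [blocks, heads, block_size(, head_dim)]."""
    num_heads, block_size = cache_shape[1], cache_shape[2]
    head_dim = cache_shape[3] if len(cache_shape) == 4 else 1
    return num_heads * block_size * head_dim

# 4-D KV data cache vs. 3-D per-block scale cache for the same block layout
assert elements_per_block([1000, 8, 64, 128]) == 8 * 64 * 128
assert elements_per_block([1000, 8, 64]) == 8 * 64
```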


@@ -83,7 +83,6 @@ def parse_args():
         "--cache_dtype",
         type=str,
         default="bfloat16",
-        choices=["uint8", "bfloat16"],
         help="cache dtype",
     )
     parser.add_argument(
parser.add_argument(
@@ -115,6 +114,8 @@ class CacheTransferManager:
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []
         self.gpu_cache_v_tensors = []
+        self.gpu_cache_scales_k_tensors = []
+        self.gpu_cache_scales_v_tensors = []
         self.speculative_config = SpeculativeConfig(args.speculative_config)
         self.num_extra_layers = self.speculative_config.num_extra_cache_layer
         self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
@@ -126,6 +127,7 @@ class CacheTransferManager:
         self.n_ranks = args.mp_num
         self.rank = rank
         self.device = device
+        self.cache_dtype = args.cache_dtype

         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
@@ -137,8 +139,11 @@ class CacheTransferManager:
         )

         self.num_cpu_blocks = args.num_cpu_blocks
-        cache_type = args.cache_dtype
+        if args.cache_dtype == "block_wise_fp8":
+            cache_type = "uint8"
+        else:
+            cache_type = args.cache_dtype

         for i in range(args.num_layers + self.num_extra_layers):
             num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
@@ -164,7 +169,6 @@ class CacheTransferManager:
                 dtype=cache_type,
             )
             self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
-
             set_data_ipc(
                 self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
                 f"key_caches_{i}_rank{rank}.device{device}",
@@ -173,6 +177,32 @@ class CacheTransferManager:
                 self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
                 f"value_caches_{i}_rank{rank}.device{device}",
             )
+            if args.cache_dtype == "block_wise_fp8":
+                self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
+                    shape=[num_gpu_blocks, args.kv_num_head, args.block_size],
+                    fill_value=0,
+                    dtype=paddle.get_default_dtype(),
+                )
+                self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
+                    shape=[num_gpu_blocks, args.kv_num_head, args.block_size],
+                    fill_value=0,
+                    dtype=paddle.get_default_dtype(),
+                )
+                self.gpu_cache_scales_k_tensors.append(
+                    self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"]
+                )
+                self.gpu_cache_scales_v_tensors.append(
+                    self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"]
+                )
+                set_data_ipc(
+                    self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"],
+                    f"key_cache_scales_{i}_rank{rank}.device{device}",
+                )
+                set_data_ipc(
+                    self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"],
+                    f"value_cache_scales_{i}_rank{rank}.device{device}",
+                )
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
         logger.info(f"device :{self.device}")
         logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
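For block_wise_fp8 the hunk above allocates, per layer, two extra per-block scale tensors shaped [num_gpu_blocks, kv_num_head, block_size] in the default dtype, and exports them over IPC under key_cache_scales_*/value_cache_scales_* names so other processes can attach to the same memory. A rough sketch of the extra GPU memory this adds per layer (argument names are illustrative; the 2-byte item size assumes a bfloat16/float16 default dtype):

```python
def scale_cache_bytes_per_layer(num_gpu_blocks, kv_num_head, block_size, scale_itemsize=2):
    """Bytes for one layer's key + value scale tensors.

    One scale per (block, kv head, token slot); two tensors (K and V) per layer.
    """
    return 2 * num_gpu_blocks * kv_num_head * block_size * scale_itemsize

# e.g. 5000 blocks, 8 KV heads, block_size 64 -> ~10.2 MB of scales per layer
print(scale_cache_bytes_per_layer(5000, 8, 64) / 1e6, "MB")
```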
@@ -181,6 +211,8 @@ class CacheTransferManager:
         paddle.set_device("cpu")
         self.k_dst_ptrs = []
         self.v_dst_ptrs = []
+        self.k_scales_ptrs = []
+        self.v_scales_ptrs = []
         for i in range(args.num_layers + self.num_extra_layers):
             self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
                 args.num_cpu_blocks * args.bytes_per_layer_per_block
@@ -190,6 +222,14 @@ class CacheTransferManager:
                 args.num_cpu_blocks * args.bytes_per_layer_per_block
             )
             self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
+            self.cpu_cache_kvs[f"key_caches_scales_{i}_rank{rank}"] = cuda_host_alloc(
+                args.num_cpu_blocks * args.bytes_per_layer_per_block
+            )
+            self.k_scales_ptrs.append(self.cpu_cache_kvs[f"key_caches_scales_{i}_rank{rank}"])
+            self.cpu_cache_kvs[f"value_caches_scales_{i}_rank{rank}"] = cuda_host_alloc(
+                args.num_cpu_blocks * args.bytes_per_layer_per_block
+            )
+            self.v_scales_ptrs.append(self.cpu_cache_kvs[f"value_caches_scales_{i}_rank{rank}"])

         cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
         self.cache_ready_signal = IPCSignal(
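On the CPU side the hierarchical cache now pins four host buffers per layer instead of two (key/value data plus key/value scales), each sized num_cpu_blocks * bytes_per_layer_per_block. A rough sizing sketch, assuming bytes_per_layer_per_block is the per-tensor data-cache figure, so the scale buffers reuse the larger data-buffer budget rather than a tighter scale-specific one (all numbers below are hypothetical):

```python
def pinned_host_bytes(num_layers, num_extra_layers, num_cpu_blocks, bytes_per_layer_per_block):
    """Pinned host memory reserved by one transfer-manager process.

    Four cuda_host_alloc buffers per layer after this change: key data, value data,
    key scales, value scales, each num_cpu_blocks * bytes_per_layer_per_block bytes.
    """
    num_buffers = 4 * (num_layers + num_extra_layers)
    return num_buffers * num_cpu_blocks * bytes_per_layer_per_block

# e.g. 48 layers, no extra layers, 2048 CPU blocks, 64 KiB per layer per block
print(pinned_host_bytes(48, 0, 2048, 64 * 1024) / 2**30, "GiB")
```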
@@ -388,6 +428,25 @@ class CacheTransferManager:
                        self.device,
                        0,
                    )
+                    if self.cache_dtype == "block_wise_fp8":
+                        swap_cache_all_layers(
+                            self.gpu_cache_scales_k_tensors,
+                            self.k_scales_ptrs,
+                            self.num_cpu_blocks,
+                            gpu_block_ids,
+                            cpu_block_ids,
+                            self.device,
+                            0,
+                        )
+                        swap_cache_all_layers(
+                            self.gpu_cache_scales_v_tensors,
+                            self.v_scales_ptrs,
+                            self.num_cpu_blocks,
+                            gpu_block_ids,
+                            cpu_block_ids,
+                            self.device,
+                            0,
+                        )
                elif event_type.value == CacheStatus.SWAP2GPU.value:
                    swap_cache_all_layers(
@@ -408,6 +467,25 @@ class CacheTransferManager:
                        self.device,
                        1,
                    )
+                    if self.cache_dtype == "block_wise_fp8":
+                        swap_cache_all_layers(
+                            self.gpu_cache_scales_k_tensors,
+                            self.k_scales_ptrs,
+                            self.num_cpu_blocks,
+                            gpu_block_ids,
+                            cpu_block_ids,
+                            self.device,
+                            1,
+                        )
+                        swap_cache_all_layers(
+                            self.gpu_cache_scales_v_tensors,
+                            self.v_scales_ptrs,
+                            self.num_cpu_blocks,
+                            gpu_block_ids,
+                            cpu_block_ids,
+                            self.device,
+                            1,
+                        )
                else:
                    logger.warning(
                        f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"

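Both branches above move the scale caches with the same call pattern as the data caches; the last positional argument is the transfer direction (0 for GPU→CPU on SWAP2CPU, 1 for CPU→GPU on SWAP2GPU). A condensed sketch of how the four swaps could be grouped (helper name and the pairing structure are illustrative, not part of the patch; swap_cache_all_layers is the same op the diff calls directly):

```python
def swap_blocks(mgr, gpu_block_ids, cpu_block_ids, direction):
    """direction: 0 = GPU -> CPU (SWAP2CPU), 1 = CPU -> GPU (SWAP2GPU).

    `mgr` stands in for the CacheTransferManager instance built above.
    """
    pairs = [
        (mgr.gpu_cache_k_tensors, mgr.k_dst_ptrs),
        (mgr.gpu_cache_v_tensors, mgr.v_dst_ptrs),
    ]
    if mgr.cache_dtype == "block_wise_fp8":
        # Per-block scales must travel with the uint8 data, or dequantization breaks.
        pairs += [
            (mgr.gpu_cache_scales_k_tensors, mgr.k_scales_ptrs),
            (mgr.gpu_cache_scales_v_tensors, mgr.v_scales_ptrs),
        ]
    for gpu_tensors, host_ptrs in pairs:
        swap_cache_all_layers(
            gpu_tensors, host_ptrs, mgr.num_cpu_blocks,
            gpu_block_ids, cpu_block_ids, mgr.device, direction,
        )
```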

@@ -997,7 +997,9 @@ class CacheConfig:
             self.enable_hierarchical_cache = True

         if self.model_cfg is not None:
-            if self.model_cfg.quantization_config is not None:
+            if self.model_cfg.quantization is not None and isinstance(self.model_cfg.quantization, dict):
+                self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype)
+            elif self.model_cfg.quantization_config is not None:
                 self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype)
             if (
                 hasattr(self.model_cfg, "num_key_value_heads")

@@ -1050,6 +1050,18 @@ class GPUModelRunner(ModelRunnerBase):
                 value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
                 cache_kvs_list.append(value_cache)
+                if kv_cache_quant_type == "block_wise_fp8":
+                    scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
+                    scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
+                    key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                    key_scale_cache = share_external_data(key_scale_cache, scale_key_cache_name, kv_cache_scale_shape)
+                    cache_kvs_list.append(key_scale_cache)
+                    value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                    value_scale_cache = share_external_data(
+                        value_scale_cache, scale_val_cache_name, kv_cache_scale_shape
+                    )
+                    cache_kvs_list.append(value_scale_cache)

             self.share_inputs["caches"] = cache_kvs_list
         else:
             for i in range(self.model_config.num_hidden_layers):
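With this hunk the runner also attaches to the IPC-exported scale tensors, so for block_wise_fp8 each layer contributes four entries to share_inputs["caches"] in the order key data, value data, key scales, value scales. A small sketch of the resulting flat-list layout (a simplification for illustration; the real runner builds the names and shapes exactly as in the diff):

```python
def cache_slots_for_layer(layer_idx, block_wise_fp8=True):
    """Index positions of one layer's tensors inside the flat `caches` list."""
    per_layer = 4 if block_wise_fp8 else 2  # [K, V] or [K, V, K_scales, V_scales]
    base = layer_idx * per_layer
    names = ["key", "value"] + (["key_scales", "value_scales"] if block_wise_fp8 else [])
    return {name: base + offset for offset, name in enumerate(names)}

assert cache_slots_for_layer(0) == {"key": 0, "value": 1, "key_scales": 2, "value_scales": 3}
assert cache_slots_for_layer(2, block_wise_fp8=False) == {"key": 4, "value": 5}
```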