Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[fix] fix clearing caches synchronization and add more logs (#4212)

* [fix] fix clearing caches synchronization and add more logs
* [chore] print cache_ready_signal in log
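For readers skimming the diff: the synchronization being fixed is signal-based. Each cache process owns one slot in a shared cache_ready_signal array, clears its own KV cache, zeroes its slot, and only once every slot reads zero is the shared kv_cache_status_signal flipped to CLEARED. The sketch below illustrates that handshake with plain multiprocessing shared arrays; it is not FastDeploy code, and names such as rank_worker, N_RANKS, and the status constants are invented for the example.

# Minimal sketch (not FastDeploy code): ranks coordinate cache clearing through
# shared signal arrays, mirroring the polling pattern used in the commit above.
import multiprocessing as mp
import time

N_RANKS = 4
CLEARING, CLEARED = 0, 1  # illustrative status values, not FastDeploy's enum


def rank_worker(rank, cache_ready_signal, kv_cache_status_signal):
    # Pretend this rank owns some KV cache it must drop; ranks finish at different times.
    time.sleep(0.05 * rank)

    # Mark this rank's cache as cleared (mirrors cache_ready_signal.value[rank] = 0).
    cache_ready_signal[rank] = 0
    print(f"[rank {rank}/{N_RANKS}] Finish clearing caches {list(cache_ready_signal)}")

    # Wait until every rank has cleared before declaring the global state CLEARED.
    while sum(cache_ready_signal) != 0:
        time.sleep(0.1)

    # Any rank may flip the status once all ranks are done; the write is idempotent.
    kv_cache_status_signal[0] = CLEARED
    print(f"[rank {rank}/{N_RANKS}] All ranks finish clearing caches")


if __name__ == "__main__":
    cache_ready = mp.Array("i", [1] * N_RANKS)   # 1 = cache exists, 0 = cleared
    kv_cache_status = mp.Array("i", [CLEARING])  # single shared status slot
    procs = [
        mp.Process(target=rank_worker, args=(r, cache_ready, kv_cache_status))
        for r in range(N_RANKS)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    assert kv_cache_status[0] == CLEARED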
@@ -201,12 +201,12 @@ class CacheTransferManager:
     def _init_gpu_cache(self, args):
 
         if not args.create_cache_tensor:
-            logger.info("Waiting for runners to create kv cache.")
+            logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners to create kv cache.")
             while self.cache_ready_signal.value[self.rank] != 1:
-                time.sleep(1)
-            logger.info("OK! Stop waiting.")
+                time.sleep(0.1)
+            logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
 
-        logger.info("Initializing kv cache for all layers.")
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
         paddle.set_device(f"gpu:{self.device}")
         for i in range(args.num_layers + self.num_extra_layers):
             num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
@@ -215,13 +215,13 @@ class CacheTransferManager:
             val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
 
             if args.create_cache_tensor:
-                logger.info(f"..creating kv cache for layer {i}: {cache_shape}")
+                logger.info(f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {cache_shape}")
                 key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
                 val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
                 set_data_ipc(key_cache, key_name)
                 set_data_ipc(val_cache, val_name)
             else:
-                logger.info(f"..attaching kv cache for layer {i}: {cache_shape}")
+                logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
                 key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
                 val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
                 key_cache = share_external_data(key_cache, key_name, cache_shape)
@@ -233,20 +233,22 @@ class CacheTransferManager:
             self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
 
         if args.create_cache_tensor:
-            logger.info("✅ kv cache is ready!")
+            logger.info("[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
             self.cache_ready_signal.value[self.rank] = 1
 
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
-        logger.info(f"device :{self.device}")
-        logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
-        logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
+        logger.info(
+            f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
+        )
 
     def _init_cpu_cache(self, args):
         if args.num_cpu_blocks == 0:
-            logger.info("💡 no swap space (cpu cache) is specified.")
+            logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
             self.swap_space_ready_signal.value[self.rank] = 1
             return
-        logger.info("Initializing swap space (cpu cache) for all layers.")
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.")
         paddle.set_device("cpu")
         self.k_dst_ptrs = []
         self.v_dst_ptrs = []
@@ -254,12 +256,14 @@ class CacheTransferManager:
             key_name = f"key_caches_{i}_rank{self.rank}"
             val_name = f"value_caches_{i}_rank{self.rank}"
             need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
-            logger.info(f"..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB")
+            logger.info(
+                f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB"
+            )
             self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
             self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
             self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
             self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
-        logger.info("✅ swap space (cpu cache) is ready!")
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
         self.swap_space_ready_signal.value[self.rank] = 1
 
     def _do_swap_to_cpu_task(
@@ -473,6 +477,10 @@ class CacheTransferManager:
         while True:
             if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
                 try:
+                    logger.info(
+                        f"[rank {self.rank}/{self.n_ranks}] Start clearing caches {self.cache_ready_signal.value}"
+                    )
+                    # clear cpu caches
                     if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                         paddle.set_device("cpu")
                         for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
@@ -486,37 +494,58 @@ class CacheTransferManager:
                         while np.sum(self.swap_space_ready_signal.value) != 0:
                             time.sleep(0.1)
 
+                    # clear gpu caches
                     paddle.set_device(f"gpu:{self.device}")
                     for name, tensor in self.gpu_cache_kvs.items():
                         unset_data_ipc(tensor, name, True, False)
                     self.gpu_cache_kvs.clear()
                     self.gpu_cache_k_tensors.clear()
                     self.gpu_cache_v_tensors.clear()
 
                     # reset cache_ready_signal
                     self.cache_ready_signal.value[self.rank] = 0
-                    if np.sum(self.cache_ready_signal.value) == 0:
+                    logger.info(
+                        f"[rank {self.rank}/{self.n_ranks}] Finish clearing caches {self.cache_ready_signal.value}"
+                    )
+
+                    # wait for all ranks caches to be cleared
+                    if np.sum(self.cache_ready_signal.value) != 0:
                         time.sleep(0.1)
+
+                    # reset kv_cache_status_signal
                     kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
+                    logger.info("All ranks finish clearing caches")
 
                 except Exception as e:
-                    logger.error(f"Failed to clear caches: {e}")
+                    logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to clear caches: {e}")
 
             elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
                 try:
+                    logger.info(
+                        f"[rank {self.rank}/{self.n_ranks}] Start restoring caches {self.cache_ready_signal.value}"
+                    )
+                    # restore cpu cache
                     if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                         self._init_cpu_cache(args)
                         while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
                             time.sleep(0.1)
 
+                    # restore gpu cache and set cache_ready_signal
                     self._init_gpu_cache(args)
+                    logger.info(
+                        f"[rank {self.rank}/{self.n_ranks}] Finish restoring caches {self.cache_ready_signal.value}"
+                    )
+
+                    # wait for all ranks caches to be ready
                     while np.sum(self.cache_ready_signal.value) != args.mp_num:
                         time.sleep(0.1)
 
+                    # set kv_cache_status_signal
+                    logger.info("All ranks finish restoring caches")
                     kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
 
                 except Exception as e:
-                    logger.error(f"Failed to restore caches: {e}")
+                    logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to restore caches: {e}")
 
             time.sleep(0.1)
 
@@ -115,7 +115,7 @@ class DynamicWeightManager:
         self._verify_parameters("clearance")
         if self.parallel_config.tensor_parallel_size > 1:
             paddle.distributed.barrier(self.parallel_config.tp_group)
             paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
         if self.parallel_config.enable_expert_parallel:
             paddle.distributed.barrier(self.parallel_config.ep_group)
             paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
@@ -1028,12 +1028,12 @@ class GPUModelRunner(ModelRunnerBase):
         create_cache_tensor = profile or self.parallel_config.splitwise_role == "mixed"
 
         if not create_cache_tensor:
-            logger.info("Waiting for cache managers to create kv cache..")
+            logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
             while cache_ready_signal.value[self.local_rank] != 1:
                 time.sleep(1)
-            logger.info("OK! Stop waiting.")
+            logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
 
-        logger.info("Initializing kv cache for all layers.")
+        logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
         cache_kvs_list = []
         for i in range(self.model_config.num_hidden_layers):
             key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
@@ -1054,8 +1054,8 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["caches"] = cache_kvs_list
 
         if not profile and create_cache_tensor:
-            logger.info("✅ kv cache is ready!")
             cache_ready_signal.value[self.local_rank] = 1
+            logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")
 
         paddle.device.cuda.empty_cache()
 