[feat] support prefix cache clearing when /clear_load_weight is called (#4091)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix code style

* [fix] wait for rank0 to update weight status
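
Taken together, these changes let the cache transfer process drop its KV cache when model weights are cleared via /clear_load_weight and rebuild it when new weights are loaded. Below is a minimal, illustrative sketch of how an engine-side caller could drive that handshake; the IPCSignal and KVCacheStatus usage mirrors the diff further down, while the helper itself (clear_kv_cache_blocking and its engine_pid/timeout_s parameters) is hypothetical.

import time

import numpy as np

from fastdeploy.inter_communicator import IPCSignal, KVCacheStatus


def clear_kv_cache_blocking(engine_pid, timeout_s=60.0):
    # Attach to the shared kv_cache_status flag; it is created elsewhere (create=False).
    status = IPCSignal(
        name="kv_cache_status",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=engine_pid,
        create=False,
    )
    # Ask every cache transfer process to free its KV blocks ...
    status.value[0] = KVCacheStatus.CLEARING
    deadline = time.time() + timeout_s
    # ... and wait until the last rank reports that the caches are gone.
    while status.value[0] != KVCacheStatus.CLEARED:
        if time.time() > deadline:
            raise TimeoutError("KV cache was not cleared in time")
        time.sleep(0.1)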
Author: 李泳桦
Date: 2025-09-16 11:11:20 +08:00 (committed by GitHub)
Parent: fbb4e0f8d1
Commit: 7ccbcc5a62
17 changed files with 624 additions and 181 deletions


@@ -16,21 +16,27 @@
import argparse
import concurrent.futures
import gc
import json
import queue
import threading
import time
import traceback
import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.model_executor.ops.gpu import (
    cuda_host_alloc,
    cuda_host_free,
    set_data_ipc,
    share_external_data,
    swap_cache_all_layers,
    unset_data_ipc,
)
from fastdeploy.utils import get_logger
@@ -93,6 +99,7 @@ def parse_args():
        help="speculative config",
    )
    parser.add_argument("--local_data_parallel_id", type=int, default=0)
    parser.add_argument("--create_cache_tensor", action="store_true")
    args = parser.parse_args()
    return args
@@ -110,7 +117,6 @@ class CacheTransferManager:
        device = args.device_id
        rank = args.rank
        paddle.set_device(f"gpu:{device}")
        self.gpu_cache_kvs = {}
        self.cpu_cache_kvs = {}
        self.gpu_cache_k_tensors = []
@@ -126,6 +132,7 @@ class CacheTransferManager:
        self.n_ranks = args.mp_num
        self.rank = rank
        self.device = device
        self.engine_pid = args.engine_pid
        address = (args.pod_ip, args.cache_queue_port)
        self.cache_task_queue = EngineCacheQueue(
@@ -136,70 +143,27 @@ class CacheTransferManager:
            local_data_parallel_id=args.local_data_parallel_id,
        )
        self.num_cpu_blocks = args.num_cpu_blocks
        cache_type = args.cache_dtype
        for i in range(args.num_layers + self.num_extra_layers):
            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
            self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
                shape=[
                    num_gpu_blocks,
                    args.kv_num_head,
                    args.block_size,
                    args.head_dim,
                ],
                fill_value=0,
                dtype=cache_type,
            )
            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
            self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
                shape=[
                    num_gpu_blocks,
                    args.kv_num_head,
                    args.block_size,
                    args.head_dim,
                ],
                fill_value=0,
                dtype=cache_type,
            )
            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
            set_data_ipc(
                self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
                f"key_caches_{i}_rank{rank}.device{device}",
            )
            set_data_ipc(
                self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
                f"value_caches_{i}_rank{rank}.device{device}",
            )
        cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
        logger.info(f"device :{self.device}")
        logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
        logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
        paddle.set_device("cpu")
        self.k_dst_ptrs = []
        self.v_dst_ptrs = []
        for i in range(args.num_layers + self.num_extra_layers):
            self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
                args.num_cpu_blocks * args.bytes_per_layer_per_block
            )
            self.k_dst_ptrs.append(self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"])
            self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"] = cuda_host_alloc(
                args.num_cpu_blocks * args.bytes_per_layer_per_block
            )
            self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
        cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
        self.cache_ready_signal = IPCSignal(
            name="cache_ready_signal",
            array=cache_ready_signal_data,
            dtype=np.int32,
            suffix=args.engine_pid,
            suffix=self.engine_pid,
            create=False,
        )
        self.cache_ready_signal.value[self.rank] = 1
        swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
        self.swap_space_ready_signal = IPCSignal(
            name="swap_space_ready_signal",
            array=swap_space_ready_data,
            dtype=np.int32,
            suffix=self.engine_pid,
            create=False,
        )
        self.num_cpu_blocks = args.num_cpu_blocks
        self._init_cpu_cache(args)
        self._init_gpu_cache(args)
        paddle.set_device(f"gpu:{device}")
        if args.enable_splitwise:
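
The swap_space_ready_signal added here follows the same per-rank ready-flag pattern as the existing cache_ready_signal: each rank owns one slot of a small shared int32 array, flips it to 1 once its piece of the cache exists, and "everything ready" is simply the sum of slots reaching mp_num (see clear_or_update_caches below). A condensed sketch of that pattern, assuming the IPCSignal interface used in this file; the wait_all_ranks_ready helper itself is hypothetical.

import time

import numpy as np

from fastdeploy.inter_communicator import IPCSignal


def wait_all_ranks_ready(name, engine_pid, mp_num, my_rank):
    # Attach to an existing per-rank ready flag, e.g. "swap_space_ready_signal".
    signal = IPCSignal(
        name=name,
        array=np.zeros([mp_num], dtype=np.int32),
        dtype=np.int32,
        suffix=engine_pid,
        create=False,
    )
    # Mark this rank as ready, then block until every other rank has done the same.
    signal.value[my_rank] = 1
    while np.sum(signal.value) != mp_num:
        time.sleep(0.1)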
@@ -232,6 +196,72 @@ class CacheTransferManager:
            create=False,
        )

        threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()

    def _init_gpu_cache(self, args):
        if not args.create_cache_tensor:
            logger.info("Waiting for runners to create kv cache.")
            while self.cache_ready_signal.value[self.rank] != 1:
                time.sleep(1)
            logger.info("OK! Stop waiting.")

        logger.info("Initializing kv cache for all layers.")
        paddle.set_device(f"gpu:{self.device}")
        for i in range(args.num_layers + self.num_extra_layers):
            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
            cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
            key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
            val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
            if args.create_cache_tensor:
                logger.info(f"..creating kv cache for layer {i}: {cache_shape}")
                key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
                val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
                set_data_ipc(key_cache, key_name)
                set_data_ipc(val_cache, val_name)
            else:
                logger.info(f"..attaching kv cache for layer {i}: {cache_shape}")
                key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
                val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
                key_cache = share_external_data(key_cache, key_name, cache_shape)
                val_cache = share_external_data(val_cache, val_name, cache_shape)
            self.gpu_cache_kvs[key_name] = key_cache
            self.gpu_cache_kvs[val_name] = val_cache
            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])

        if args.create_cache_tensor:
            logger.info("✅ kv cache is ready!")
            self.cache_ready_signal.value[self.rank] = 1

        cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
        logger.info(f"device :{self.device}")
        logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
        logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")

    def _init_cpu_cache(self, args):
        if args.num_cpu_blocks == 0:
            logger.info("💡 no swap space (cpu cache) is specified.")
            self.swap_space_ready_signal.value[self.rank] = 1
            return

        logger.info("Initializing swap space (cpu cache) for all layers.")
        paddle.set_device("cpu")
        self.k_dst_ptrs = []
        self.v_dst_ptrs = []
        for i in range(args.num_layers + self.num_extra_layers):
            key_name = f"key_caches_{i}_rank{self.rank}"
            val_name = f"value_caches_{i}_rank{self.rank}"
            need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
            logger.info(f"..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB")
            self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
            self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
            self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
            self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])

        logger.info("✅ swap space (cpu cache) is ready!")
        self.swap_space_ready_signal.value[self.rank] = 1

    def _do_swap_to_cpu_task(
        self,
        swap_node_ids,
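
The new --create_cache_tensor flag controls whether this process allocates the KV cache tensors and publishes them over CUDA IPC (set_data_ipc), or waits for another process to create them and attaches via share_external_data. A minimal sketch of that create-vs-attach decision for a single tensor, mirroring the op usage in _init_gpu_cache above; the get_or_attach_cache helper itself is hypothetical.

import paddle

from fastdeploy.model_executor.ops.gpu import set_data_ipc, share_external_data


def get_or_attach_cache(name, shape, dtype, create):
    if create:
        # This process owns the allocation and publishes it under `name`.
        cache = paddle.full(shape=shape, fill_value=0, dtype=dtype)
        set_data_ipc(cache, name)
    else:
        # Another process already published it; attach a view onto the shared memory.
        cache = paddle.empty(shape=[], dtype=dtype)
        cache = share_external_data(cache, name, shape)
    return cache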
@@ -429,6 +459,67 @@ class CacheTransferManager:
            transfer_task_id,
        )

    def clear_or_update_caches(self, args):
        logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
        logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
        kv_cache_status = np.zeros([1], dtype=np.int32)
        kv_cache_status_signal = IPCSignal(
            name="kv_cache_status",
            array=kv_cache_status,
            dtype=np.int32,
            suffix=self.engine_pid,
            create=False,
        )
        while True:
            if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
                try:
                    if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                        paddle.set_device("cpu")
                        for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
                            cuda_host_free(ptrs)
                        self.cpu_cache_kvs.clear()
                        self.k_dst_ptrs.clear()
                        self.v_dst_ptrs.clear()
                        gc.collect()
                        # reset swap_space_ready_signal
                        self.swap_space_ready_signal.value[self.rank] = 0
                        while np.sum(self.swap_space_ready_signal.value) != 0:
                            time.sleep(0.1)
                    paddle.set_device(f"gpu:{self.device}")
                    for name, tensor in self.gpu_cache_kvs.items():
                        unset_data_ipc(tensor, name, True, False)
                    self.gpu_cache_kvs.clear()
                    self.gpu_cache_k_tensors.clear()
                    self.gpu_cache_v_tensors.clear()
                    # reset cache_ready_signal
                    self.cache_ready_signal.value[self.rank] = 0
                    if np.sum(self.cache_ready_signal.value) == 0:
                        time.sleep(0.1)
                        kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
                except Exception as e:
                    logger.error(f"Failed to clear caches: {e}")
            elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
                try:
                    if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                        self._init_cpu_cache(args)
                        while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
                            time.sleep(0.1)
                    self._init_gpu_cache(args)
                    while np.sum(self.cache_ready_signal.value) != args.mp_num:
                        time.sleep(0.1)
                    kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
                except Exception as e:
                    logger.error(f"Failed to restore caches: {e}")
            time.sleep(0.1)

def main():
    """