[XPU] support prefix cache (#4423)

Co-authored-by: ddchenhao66 <dhaochen@163.com>
ddchenhao66 authored 2025-10-16 11:27:41 +08:00, committed by GitHub
parent 5bde20b0c9
commit 8e392f0ea6
4 changed files with 112 additions and 45 deletions


@@ -30,14 +30,25 @@ from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import CacheStatus
 from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
-from fastdeploy.model_executor.ops.gpu import (
-    cuda_host_alloc,
-    cuda_host_free,
-    set_data_ipc,
-    share_external_data,
-    swap_cache_all_layers,
-    unset_data_ipc,
-)
+from fastdeploy.platforms import current_platform
+
+if current_platform.is_cuda():
+    from fastdeploy.model_executor.ops.gpu import (
+        cuda_host_alloc,
+        cuda_host_free,
+        set_data_ipc,
+        share_external_data,
+        swap_cache_all_layers,
+        unset_data_ipc,
+    )
+elif current_platform.is_xpu():
+    from fastdeploy.model_executor.ops.xpu import (
+        cuda_host_alloc,
+        cuda_host_free,
+        set_data_ipc,
+        share_external_data,
+        swap_cache_all_layers,
+    )
 from fastdeploy.utils import get_logger
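
Note on the import change above: gating the op imports on current_platform keeps the rest of the module backend-agnostic, since both packages export the same operator names; the XPU list omits unset_data_ipc, which XPU never reaches because it does not start the cache-clearing thread (see the threading.Thread hunk further down). As a hedged illustration of the same idea, a lookup helper could resolve the backend module by name; load_cache_ops and the _BACKENDS table below are hypothetical, not part of this commit:

# Sketch only: resolve the platform-specific cache op module.
# Module paths mirror the imports in this commit; the helper itself is illustrative.
import importlib

_BACKENDS = {
    "cuda": "fastdeploy.model_executor.ops.gpu",
    "xpu": "fastdeploy.model_executor.ops.xpu",
}

def load_cache_ops(platform: str):
    """Return the module providing cuda_host_alloc, set_data_ipc, share_external_data, etc."""
    if platform not in _BACKENDS:
        raise ValueError(f"unsupported platform: {platform}")
    return importlib.import_module(_BACKENDS[platform])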
@@ -114,7 +125,6 @@ class CacheTransferManager:
"""
初始化CacheTransferManager
"""
device = args.device_id
rank = args.rank
self.gpu_cache_kvs = {}
@@ -173,8 +183,9 @@ class CacheTransferManager:
             suffix=args.engine_pid,
             create=False,
         )
-        threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
+        # TODO XPU support RL
+        if not current_platform.is_xpu():
+            threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()

     def _init_gpu_cache(self, args):
@@ -185,7 +196,10 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
paddle.set_device(f"gpu:{self.device}")
if current_platform.is_cuda():
paddle.set_device(f"gpu:{self.device}")
elif current_platform.is_xpu():
paddle.set_device(f"xpu:{self.device}")
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
@@ -202,8 +216,12 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
key_cache = share_external_data(key_cache, key_name, cache_shape)
val_cache = share_external_data(val_cache, val_name, cache_shape)
if current_platform.is_xpu():
key_cache = share_external_data(key_cache, key_name, cache_shape, True)
val_cache = share_external_data(val_cache, val_name, cache_shape, True)
else:
key_cache = share_external_data(key_cache, key_name, cache_shape)
val_cache = share_external_data(val_cache, val_name, cache_shape)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[val_name] = val_cache
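
In the attach path above, the XPU build of share_external_data takes an extra trailing positional flag (the True arguments); this diff does not spell out its meaning. A minimal sketch of hiding that signature difference behind one helper, assuming the gated imports from the top of this file; attach_external_cache is an illustrative name, not something the commit adds:

# Sketch only: attach an externally allocated KV cache tensor by the name it was
# published under, passing the extra XPU-only positional flag seen in this commit.
import paddle
from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import share_external_data
elif current_platform.is_xpu():
    from fastdeploy.model_executor.ops.xpu import share_external_data

def attach_external_cache(name: str, shape: list, dtype: str):
    placeholder = paddle.empty(shape=[], dtype=dtype)  # empty handle to be rebound
    if current_platform.is_xpu():
        return share_external_data(placeholder, name, shape, True)
    return share_external_data(placeholder, name, shape)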
@@ -217,9 +235,10 @@ class CacheTransferManager:
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
         logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
         logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
-        logger.info(
-            f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
-        )
+        if current_platform.is_cuda():
+            logger.info(
+                f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
+            )

     def _init_cpu_cache(self, args):
         if args.num_cpu_blocks == 0:
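
The post-init memory log above is now emitted only on CUDA because paddle.device.cuda.memory_allocated() is a CUDA-specific query. If a uniform log line were wanted, a small helper could degrade gracefully on other backends; allocated_device_mem is a hypothetical name, not part of this commit:

# Sketch only: report allocated device memory where the CUDA query applies.
import paddle
from fastdeploy.platforms import current_platform

def allocated_device_mem() -> str:
    if current_platform.is_cuda():
        return str(paddle.device.cuda.memory_allocated())
    return "n/a (non-CUDA backend)"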
@@ -473,7 +492,10 @@ class CacheTransferManager:
                 time.sleep(0.1)

             # clear gpu caches
-            paddle.set_device(f"gpu:{self.device}")
+            if current_platform.is_cuda():
+                paddle.set_device(f"gpu:{self.device}")
+            elif current_platform.is_xpu():
+                paddle.set_device(f"xpu:{self.device}")
             for name, tensor in self.gpu_cache_kvs.items():
                 unset_data_ipc(tensor, name, True, False)
             self.gpu_cache_kvs.clear()
@@ -543,5 +565,8 @@ if __name__ == "__main__":
     args = parse_args()
     rank_id = args.rank + args.local_data_parallel_id * args.mp_num
     logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
-    paddle.set_device(f"gpu:{args.device_id}")
+    if current_platform.is_cuda():
+        paddle.set_device(f"gpu:{args.device_id}")
+    elif current_platform.is_xpu():
+        paddle.set_device(f"xpu:{args.device_id}")
     main()
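
The same is_cuda()/is_xpu() device-selection branch now appears three times in this file (KV cache init, cache clearing, and the __main__ entry point). A hedged sketch of a single helper that could consolidate it; set_kv_cache_device is a hypothetical name and not something this commit introduces:

# Sketch only: one place to bind the worker process to its GPU or XPU device.
import paddle
from fastdeploy.platforms import current_platform

def set_kv_cache_device(device_id: int) -> None:
    if current_platform.is_cuda():
        paddle.set_device(f"gpu:{device_id}")
    elif current_platform.is_xpu():
        paddle.set_device(f"xpu:{device_id}")
    else:
        raise RuntimeError("prefix cache transfer supports only CUDA and XPU devices here")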