mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] support prefix cache (#4423)
Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
@@ -30,14 +30,25 @@ from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import CacheStatus
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
cuda_host_alloc,
|
||||
cuda_host_free,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
swap_cache_all_layers,
|
||||
unset_data_ipc,
|
||||
)
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
cuda_host_alloc,
|
||||
cuda_host_free,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
swap_cache_all_layers,
|
||||
unset_data_ipc,
|
||||
)
|
||||
elif current_platform.is_xpu():
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
cuda_host_alloc,
|
||||
cuda_host_free,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
swap_cache_all_layers,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
|
||||
@@ -114,7 +125,6 @@ class CacheTransferManager:
|
||||
"""
|
||||
初始化CacheTransferManager
|
||||
"""
|
||||
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
self.gpu_cache_kvs = {}
|
||||
@@ -173,8 +183,9 @@ class CacheTransferManager:
|
||||
suffix=args.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
|
||||
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
|
||||
# TODO XPU support RL
|
||||
if not current_platform.is_xpu():
|
||||
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
|
||||
|
||||
def _init_gpu_cache(self, args):
|
||||
|
||||
@@ -185,7 +196,10 @@ class CacheTransferManager:
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
|
||||
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
|
||||
paddle.set_device(f"gpu:{self.device}")
|
||||
if current_platform.is_cuda():
|
||||
paddle.set_device(f"gpu:{self.device}")
|
||||
elif current_platform.is_xpu():
|
||||
paddle.set_device(f"xpu:{self.device}")
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
|
||||
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
|
||||
@@ -202,8 +216,12 @@ class CacheTransferManager:
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
|
||||
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
|
||||
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
|
||||
key_cache = share_external_data(key_cache, key_name, cache_shape)
|
||||
val_cache = share_external_data(val_cache, val_name, cache_shape)
|
||||
if current_platform.is_xpu():
|
||||
key_cache = share_external_data(key_cache, key_name, cache_shape, True)
|
||||
val_cache = share_external_data(val_cache, val_name, cache_shape, True)
|
||||
else:
|
||||
key_cache = share_external_data(key_cache, key_name, cache_shape)
|
||||
val_cache = share_external_data(val_cache, val_name, cache_shape)
|
||||
|
||||
self.gpu_cache_kvs[key_name] = key_cache
|
||||
self.gpu_cache_kvs[val_name] = val_cache
|
||||
@@ -217,9 +235,10 @@ class CacheTransferManager:
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
logger.info(
|
||||
f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
|
||||
)
|
||||
if current_platform.is_cuda():
|
||||
logger.info(
|
||||
f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
|
||||
)
|
||||
|
||||
def _init_cpu_cache(self, args):
|
||||
if args.num_cpu_blocks == 0:
|
||||
@@ -473,7 +492,10 @@ class CacheTransferManager:
|
||||
time.sleep(0.1)
|
||||
|
||||
# clear gpu caches
|
||||
paddle.set_device(f"gpu:{self.device}")
|
||||
if current_platform.is_cuda():
|
||||
paddle.set_device(f"gpu:{self.device}")
|
||||
elif current_platform.is_xpu():
|
||||
paddle.set_device(f"xpu:{self.device}")
|
||||
for name, tensor in self.gpu_cache_kvs.items():
|
||||
unset_data_ipc(tensor, name, True, False)
|
||||
self.gpu_cache_kvs.clear()
|
||||
@@ -543,5 +565,8 @@ if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
|
||||
paddle.set_device(f"gpu:{args.device_id}")
|
||||
if current_platform.is_cuda():
|
||||
paddle.set_device(f"gpu:{args.device_id}")
|
||||
elif current_platform.is_xpu():
|
||||
paddle.set_device(f"xpu:{args.device_id}")
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user