diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py
index f39232cd2..124988fce 100644
--- a/fastdeploy/cache_manager/cache_transfer_manager.py
+++ b/fastdeploy/cache_manager/cache_transfer_manager.py
@@ -30,14 +30,25 @@ from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import CacheStatus
 from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
-from fastdeploy.model_executor.ops.gpu import (
-    cuda_host_alloc,
-    cuda_host_free,
-    set_data_ipc,
-    share_external_data,
-    swap_cache_all_layers,
-    unset_data_ipc,
-)
+from fastdeploy.platforms import current_platform
+
+if current_platform.is_cuda():
+    from fastdeploy.model_executor.ops.gpu import (
+        cuda_host_alloc,
+        cuda_host_free,
+        set_data_ipc,
+        share_external_data,
+        swap_cache_all_layers,
+        unset_data_ipc,
+    )
+elif current_platform.is_xpu():
+    from fastdeploy.model_executor.ops.xpu import (
+        cuda_host_alloc,
+        cuda_host_free,
+        set_data_ipc,
+        share_external_data,
+        swap_cache_all_layers,
+    )
 from fastdeploy.utils import get_logger
 
@@ -114,7 +125,6 @@ class CacheTransferManager:
         """
         初始化CacheTransferManager
         """
-        device = args.device_id
         rank = args.rank
 
         self.gpu_cache_kvs = {}
@@ -173,8 +183,9 @@ class CacheTransferManager:
             suffix=args.engine_pid,
             create=False,
         )
-
-        threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
+        # TODO XPU support RL
+        if not current_platform.is_xpu():
+            threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
 
     def _init_gpu_cache(self, args):
@@ -185,7 +196,10 @@ class CacheTransferManager:
         logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
 
         logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
-        paddle.set_device(f"gpu:{self.device}")
+        if current_platform.is_cuda():
+            paddle.set_device(f"gpu:{self.device}")
+        elif current_platform.is_xpu():
+            paddle.set_device(f"xpu:{self.device}")
         for i in range(args.num_layers + self.num_extra_layers):
             num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
             cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
@@ -202,8 +216,12 @@ class CacheTransferManager:
                 logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
                 key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
                 val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
-                key_cache = share_external_data(key_cache, key_name, cache_shape)
-                val_cache = share_external_data(val_cache, val_name, cache_shape)
+                if current_platform.is_xpu():
+                    key_cache = share_external_data(key_cache, key_name, cache_shape, True)
+                    val_cache = share_external_data(val_cache, val_name, cache_shape, True)
+                else:
+                    key_cache = share_external_data(key_cache, key_name, cache_shape)
+                    val_cache = share_external_data(val_cache, val_name, cache_shape)
 
             self.gpu_cache_kvs[key_name] = key_cache
             self.gpu_cache_kvs[val_name] = val_cache
@@ -217,9 +235,10 @@ class CacheTransferManager:
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
         logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
         logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
-        logger.info(
-            f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
-        )
+        if current_platform.is_cuda():
+            logger.info(
+                f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
+            )
 
     def _init_cpu_cache(self, args):
         if args.num_cpu_blocks == 0:
@@ -473,7 +492,10 @@ class CacheTransferManager:
                     time.sleep(0.1)
 
                 # clear gpu caches
-                paddle.set_device(f"gpu:{self.device}")
+                if current_platform.is_cuda():
+                    paddle.set_device(f"gpu:{self.device}")
+                elif current_platform.is_xpu():
+                    paddle.set_device(f"xpu:{self.device}")
                 for name, tensor in self.gpu_cache_kvs.items():
                     unset_data_ipc(tensor, name, True, False)
                 self.gpu_cache_kvs.clear()
@@ -543,5 +565,8 @@ if __name__ == "__main__":
     args = parse_args()
    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
     logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
-    paddle.set_device(f"gpu:{args.device_id}")
+    if current_platform.is_cuda():
+        paddle.set_device(f"gpu:{args.device_id}")
+    elif current_platform.is_xpu():
+        paddle.set_device(f"xpu:{args.device_id}")
     main()
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 123b33cfd..d24923bd2 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -410,7 +410,7 @@ class EngineArgs:
             self.enable_prefix_caching = False
         if self.speculative_config is not None:
             self.enable_prefix_caching = False
-        if not current_platform.is_cuda():
+        if not current_platform.is_cuda() and not current_platform.is_xpu():
             self.enable_prefix_caching = False
         # if self.dynamic_load_weight:
         #     self.enable_prefix_caching = False
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index f04680498..985e2a911 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -26,6 +26,7 @@ from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
+from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
 from fastdeploy.model_executor.graph_optimization.utils import (
     profile_run_guard,
@@ -45,6 +46,8 @@ from fastdeploy.model_executor.ops.xpu import (
     get_infer_param,
     get_padding_offset,
     recover_decode_task,
+    set_data_ipc,
+    share_external_data,
     update_inputs_v1,
 )
 from fastdeploy.utils import get_logger
@@ -335,11 +338,19 @@ def step_paddle(
 class XPUModelRunner(ModelRunnerBase):
     """ """
 
-    def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int):
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        device: str,  # logical device
+        device_id: int,  # physical device id
+        rank: int,
+        local_rank: int,
+    ):
         super().__init__(fd_config=fd_config, device=device)
         self.enable_mm = self.model_config.enable_mm
         self.rank = rank
         self.local_rank = local_rank
+        self.device_id = device_id
         self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
 
         # VL model config:
@@ -895,11 +906,11 @@ class XPUModelRunner(ModelRunnerBase):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """
-        cache_kvs = {}
+        # cache_kvs = {}
         max_block_num = self.num_gpu_blocks
 
         # Get kv cache dtype
@@ -914,21 +925,56 @@ class XPUModelRunner(ModelRunnerBase):
         # Get kv cache shape
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+
+        cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
+        cache_ready_signal = IPCSignal(
+            name="cache_ready_signal",
+            array=cache_ready_signal_data,
+            dtype=np.int32,
+            suffix=self.parallel_config.engine_worker_queue_port,
+            create=False,
+        )
+
+        # Check whether the XPU runner needs to create the kv cache itself:
+        # 1. During profiling, it creates its own kv cache.
+        # 2. The runner creates the kv cache tensors unless p/d disaggregation is enabled.
+        create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed"
+        if not create_cache_tensor:
+            logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
+            while cache_ready_signal.value[local_rank] != 1:
+                time.sleep(1)
+            logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
+
+        logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
+        cache_kvs_list = []
         for i in range(self.model_config.num_hidden_layers):
-            cache_kvs[f"key_caches_{i}"] = paddle.full(
-                shape=kv_cache_shape,
-                fill_value=0,
-                dtype=cache_type,
-            )
-            cache_kvs[f"value_caches_{i}"] = paddle.full(
-                shape=kv_cache_shape,
-                fill_value=0,
-                dtype=cache_type,
-            )
-        self.share_inputs["caches"] = list(cache_kvs.values())
-        for value in cache_kvs.values():
-            del value
+            key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+
+            if create_cache_tensor:
+                logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
+                key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                set_data_ipc(key_cache, key_cache_name)
+                val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                set_data_ipc(val_cache, val_cache_name)
+                cache_kvs_list.extend([key_cache, val_cache])
+
+            else:
+                logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
+                key_cache = paddle.empty(shape=[], dtype=cache_type)
+                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape, False)
+                val_cache = paddle.empty(shape=[], dtype=cache_type)
+                val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape, False)
+                cache_kvs_list.extend([key_cache, val_cache])
+
+        self.share_inputs["caches"] = cache_kvs_list
+
+        if not profile and create_cache_tensor:
+            cache_ready_signal.value[local_rank] = 1
+            logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")
+
         paddle.device.xpu.empty_cache()
 
     def initialize_attn_backend(self) -> None:
@@ -1138,18 +1184,12 @@ class XPUModelRunner(ModelRunnerBase):
 
         return None
 
-    def prepare_profile(self) -> None:
-        """Prepare the profile run by setting the block number and initializing the KV cache."""
-        paddle.device.xpu.empty_cache()
-        self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
-
     @profile_run_guard(True)
     def profile_run(self) -> None:
         """Execute a forward pass with dummy inputs to profile the memory usage of the model"""
 
         self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)
 
         self._dummy_run(
             num_tokens=int(self.scheduler_config.max_num_batched_tokens),
diff --git a/fastdeploy/worker/xpu_worker.py b/fastdeploy/worker/xpu_worker.py
index ef7450ec7..6ffbb4f26 100644
--- a/fastdeploy/worker/xpu_worker.py
+++ b/fastdeploy/worker/xpu_worker.py
@@ -23,6 +23,7 @@ from paddle import nn
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
+from fastdeploy.platforms import current_platform
 from fastdeploy.utils import get_logger, set_random_seed
 from fastdeploy.worker.output import ModelRunnerOutput
 from fastdeploy.worker.worker_base import WorkerBase
@@ -49,10 +50,11 @@ class XpuWorker(WorkerBase):
 
     def init_device(self):
         """Initialize device and Construct model runner"""
+        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         if paddle.is_compiled_with_xpu():
             # Set environment variable
             self.device_ids = self.parallel_config.device_ids.split(",")
-            self.device = f"xpu:{self.local_rank}"
+            self.device = f"xpu:{self.local_rank % self.max_chips_per_node}"
             paddle.device.set_device(self.device)
             paddle.set_default_dtype(self.parallel_config.dtype)
@@ -67,6 +69,7 @@ class XpuWorker(WorkerBase):
             fd_config=self.fd_config,
             device=self.device,
             rank=self.rank,
+            device_id=int(self.device_ids[self.local_rank % self.max_chips_per_node]),
             local_rank=self.local_rank,
         )
@@ -109,7 +112,6 @@ class XpuWorker(WorkerBase):
             used_memory: {used_memory}, free_memory: {free_memory}"
         )
 
-        self.model_runner.prepare_profile()
         if self.parallel_config.use_ep:
             logger.warning("EP mode does not support profile run.")
         else:
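
For context, the xpu_model_runner.py and cache_transfer_manager.py changes rely on a simple create/attach handshake over shared XPU memory. The following is an illustrative sketch only (not part of the patch): it reuses the IPCSignal, set_data_ipc, and share_external_data calls exactly as they appear above, while TP_SIZE, QUEUE_PORT, KV_SHAPE, and CACHE_DTYPE are placeholder values, and the signal is assumed to have been created elsewhere by the engine.

# Illustrative sketch of the create/attach handshake used above (placeholder values).
import time

import numpy as np
import paddle

from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.ops.xpu import set_data_ipc, share_external_data

TP_SIZE = 1                    # placeholder tensor_parallel_size
QUEUE_PORT = 8002              # placeholder engine_worker_queue_port
KV_SHAPE = [1024, 8, 64, 128]  # placeholder kv cache shape
CACHE_DTYPE = "bfloat16"       # placeholder cache dtype

# Both sides open the same per-rank readiness flag (assumed created by the engine).
cache_ready_signal = IPCSignal(
    name="cache_ready_signal",
    array=np.zeros(shape=[TP_SIZE], dtype=np.int32),
    dtype=np.int32,
    suffix=QUEUE_PORT,
    create=False,
)


def create_side(rank: int) -> None:
    """Creator (profile run, or splitwise_role == "mixed"): allocate, export, then signal."""
    cache = paddle.full(shape=KV_SHAPE, fill_value=0, dtype=CACHE_DTYPE)
    set_data_ipc(cache, f"key_caches_0_rank{rank}.device0")
    cache_ready_signal.value[rank] = 1


def attach_side(rank: int) -> paddle.Tensor:
    """Attacher (p/d disaggregation): wait for the signal, then map the exported buffer."""
    while cache_ready_signal.value[rank] != 1:
        time.sleep(1)
    holder = paddle.empty(shape=[], dtype=CACHE_DTYPE)
    return share_external_data(holder, f"key_caches_0_rank{rank}.device0", KV_SHAPE, False)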