diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py
index 1cc8f4d31..547fcfd60 100644
--- a/fastdeploy/cache_manager/cache_messager.py
+++ b/fastdeploy/cache_manager/cache_messager.py
@@ -14,18 +14,72 @@
 # limitations under the License.
 """
 
+import argparse
+import json
 import math
-import threading
 import time
-
+import threading
 import numpy as np
 import paddle
 
 from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
+from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
+from fastdeploy.model_executor.ops.gpu import set_data_ipc
 from fastdeploy.utils import get_logger
 
-logger = get_logger("cache_messager", "cache_messager.log")
+
+def parse_args():
+    """
+    Parse arguments from the command line
+    """
+    parser = argparse.ArgumentParser("Cache Messager")
+    parser.add_argument(
+        "--splitwise_role",
+        type=str,
+        default="mixed",
+        help="splitwise role, can be decode, prefill or mixed",
+    )
+    parser.add_argument("--rank", type=int, default=0, help="current rank")
+    parser.add_argument("--device_id", type=int, default=0, help="device id")
+    parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
+    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
+    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
+    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
+    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
+    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument(
+        "--protocol",
+        type=str,
+        default="ipc",
+        help="cache transfer protocol, only support ipc now",
+    )
+    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
+    parser.add_argument(
+        "--engine_worker_queue_port",
+        type=int,
+        default=9923,
+        help="engine worker queue port",
+    )
+    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
+    parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
+    parser.add_argument(
+        "--cache_dtype",
+        type=str,
+        default="bfloat16",
+        choices=["uint8", "bfloat16"],
+        help="cache dtype",
+    )
+    parser.add_argument(
+        "--speculative_config",
+        type=json.loads,
+        default="{}",
+        help="speculative config",
+    )
+    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+
+    args = parser.parse_args()
+    return args
 
 
 class CacheMessager:
@@ -43,7 +97,7 @@ class CacheMessager:
         gpu_cache_kvs,
         rank,
         nranks,
-        num_layers,
+        num_hidden_layers,
         gpu_id=0,
         rdma_port=None,
     ):
@@ -57,7 +111,7 @@ class CacheMessager:
             gpu_cache_kvs (dict): GPU kv cache
             rank (int): current rank
             nranks (int): global rank number
-            num_layers (int): model layer number
+            num_hidden_layers (int): model layer number
             gpu_id (int, optional): GPU ID
             rdma_port (int, optional): RDMA port
 
@@ -86,13 +140,13 @@ class CacheMessager:
         logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
         # 1. 
initialize the cache_k_ptr_list and cache_v_ptr_list - self.num_layers = num_layers + self.num_hidden_layers = num_hidden_layers cache_k_ptr_list = [] cache_v_ptr_list = [] cache_k = [] cache_v = [] self.messager = {} - for layer_idx in range(self.num_layers): + for layer_idx in range(self.num_hidden_layers): key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] cache_k.append(key_cache) @@ -109,7 +163,7 @@ class CacheMessager: if key_cache.dtype == paddle.bfloat16: block_bytes *= 2 logger.info( - f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, " + f"layers {num_hidden_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, " f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}" ) self.block_bytes = block_bytes @@ -144,17 +198,13 @@ class CacheMessager: self.cache_info = dict() self.rank_id = self.rank + local_data_parallel_id * self.nranks # align with engine worker rank (paddle.distributed.launch) - layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread) - layerwise_send_cache_thread.daemon = True - layerwise_send_cache_thread.start() - connect_rdma_thread = threading.Thread(target=self._handle_connect_task) connect_rdma_thread.daemon = True connect_rdma_thread.start() logger.info(f"cache messager init finished, use {transfer_protocol}") - def _prefill_layerwise_send_cache_thread(self): + def prefill_layerwise_send_cache_thread(self): """ layerwise_send_cache_thread: send cache to other instance @@ -204,7 +254,7 @@ class CacheMessager: cache_info = self.engine_worker_queue.get_cache_info() if cache_info: - logger.debug(f"cache info {cache_info}") + logger.info(f"cache info {cache_info}") for info in cache_info: if info["request_id"] in self.cache_info: self.cache_info[info["request_id"]].update(info) @@ -223,7 +273,7 @@ class CacheMessager: self.cache_info[info["request_id"]] = info prefilled_layer_idx = layer_shm_value.value[0] prefilled_step_idx = step_shm_value.value[0] - if prefilled_layer_idx == self.num_layers - 1: + if prefilled_layer_idx == self.num_hidden_layers - 1: time.sleep(0.001) prefilled_layer_idx = layer_shm_value.value[0] prefilled_step_idx = step_shm_value.value[0] @@ -234,7 +284,7 @@ class CacheMessager: if not self.cache_info: time.sleep(0.001) continue - logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}") + logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}") for req_id, item in list(self.cache_info.items()): if "status" not in item: continue @@ -251,7 +301,7 @@ class CacheMessager: target_id = int(item["rdma_ports"][self.rank]) status = self.messager[current_transfer_protocol].connect(target_ip, target_id) if not status: - logger.error(f"connect to {target_ip}:{target_id} failed") + logger.info(f"connect to {target_ip}:{target_id} failed") item["status"] = "error" self.engine_worker_queue.finish_request_barrier.wait() if self.rank == 0: @@ -263,7 +313,7 @@ class CacheMessager: src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu") dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu") if item["current_id"] < prefilled_step_idx: - current_layer_idx = self.num_layers + current_layer_idx = self.num_hidden_layers else: current_layer_idx = prefilled_layer_idx + 1 @@ -281,7 +331,7 @@ class CacheMessager: 
self.engine_worker_queue.finish_request_barrier.wait() if self.rank == 0: self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")]) - logger.error( + logger.info( f"write cache failed, layer_idx: {layer_idx}, " f"req_id: {item['request_id']}, dest_ip: {target_ip}" ) @@ -292,14 +342,14 @@ class CacheMessager: block_num = len(src_block_ids) avg_time_per_block = cost_time * 1000 / block_num # ms send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s - logger.debug( + logger.info( f"finish write cache for a layer, {item['request_id']}, {layer_idx}" f" {current_transfer_protocol}" f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)}," f"avg_time per block(ms): {round(avg_time_per_block, 5)}" ) item["layer_idx"] = current_layer_idx - if item["layer_idx"] == self.num_layers: + if item["layer_idx"] == self.num_hidden_layers: if item["transfer_protocol"] == "ipc": self.messager["ipc"].write_block_by_sync(target_id) logger.info(f"finish write cache {item['request_id']}") @@ -313,8 +363,8 @@ class CacheMessager: self.last_layer_idx = prefilled_layer_idx except Exception as e: - logger.error(f"prefill layerwise send cache thread has exception: {e}") - + logger.info(f"prefill layerwise send cache thread has exception: {e}") + def _handle_connect_task(self): while True: try: @@ -333,3 +383,90 @@ class CacheMessager: self.engine_worker_queue.put_connect_rdma_task_response(response) except Exception as e: logger.error(f"handle_connect_task has exception: {e}") + + +def main(): + device = args.device_id + rank = args.rank + paddle.set_device(f"gpu:{device}") + cache_type = args.cache_dtype + speculative_config = SpeculativeConfig(args.speculative_config) + num_extra_layers = speculative_config.num_extra_cache_layer + num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio) + gpu_cache_kvs = {} + gpu_cache_k_tensors = [] + gpu_cache_v_tensors = [] + + for i in range(args.num_hidden_layers + num_extra_layers): + num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else num_extra_layer_gpu_blocks + + gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full( + shape=[ + num_gpu_blocks, + args.kv_num_head, + args.block_size, + args.head_dim, + ], + fill_value=0, + dtype=cache_type, + ) + gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"]) + gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full( + shape=[ + num_gpu_blocks, + args.kv_num_head, + args.block_size, + args.head_dim, + ], + fill_value=0, + dtype=cache_type, + ) + gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"]) + + set_data_ipc( + gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"], + f"key_caches_{i}_rank{rank}.device{device}", + ) + set_data_ipc( + gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"], + f"value_caches_{i}_rank{rank}.device{device}", + ) + cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()]) + logger.info(f"device :{device}") + logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}") + logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}") + + cache_messager = CacheMessager( + splitwise_role=args.splitwise_role, + transfer_protocol=args.protocol, + pod_ip=args.pod_ip, + engine_worker_queue_port=args.engine_worker_queue_port, + local_data_parallel_id=args.local_data_parallel_id, + 
gpu_cache_kvs=gpu_cache_kvs, + rank=rank, + nranks=args.mp_num, + num_hidden_layers=args.num_hidden_layers + num_extra_layers, + gpu_id=device, + rdma_port=args.rdma_port, + ) + + cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32) + cache_ready_signal = IPCSignal( + name="cache_ready_signal", + array=cache_ready_signal_data, + dtype=np.int32, + suffix=args.engine_pid, + create=False, + ) + cache_ready_signal.value[rank] = 1 + cache_messager.prefill_layerwise_send_cache_thread() + + +if __name__ == "__main__": + + args = parse_args() + logger = get_logger("cache_messager", "cache_messager.log") + + logger.info("create cache messager...") + logger.info(f"{args}") + main() diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 34ccf144c..c9f062201 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -28,7 +28,7 @@ from fastdeploy.config import SpeculativeConfig from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal from fastdeploy.model_executor.ops.gpu import ( cuda_host_alloc, - set_data_ipc, + share_external_data, swap_cache_all_layers, ) from fastdeploy.utils import get_logger @@ -39,26 +39,12 @@ def parse_args(): 从命令行解析参数 """ parser = argparse.ArgumentParser("Cache transfer manager") - parser.add_argument( - "--splitwise_role", - type=str, - default="mixed", - help="splitwise role, can be decode, prefill or mixed", - ) parser.add_argument("--rank", type=int, default=0, help="current rank") parser.add_argument("--device_id", type=int, default=0, help="device id") - parser.add_argument("--num_layers", type=int, default=1, help="model num layers") + parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers") parser.add_argument("--head_dim", type=int, default=1, help="model head dim") parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head") - parser.add_argument("--rdma_port", type=str, default="", help="rmda port") parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel") - parser.add_argument( - "--protocol", - type=str, - default="ipc", - help="cache transfer protocol, only surport ipc now", - ) - parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ") parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port") parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip") parser.add_argument( @@ -68,7 +54,6 @@ def parse_args(): help="engine worker queue port", ) parser.add_argument("--engine_pid", type=str, default=None, help="engine pid") - parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number") parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number") parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)") @@ -109,7 +94,6 @@ class CacheTransferManager: device = args.device_id rank = args.rank - paddle.set_device(f"gpu:{device}") self.gpu_cache_kvs = {} self.cpu_cache_kvs = {} self.gpu_cache_k_tensors = [] @@ -138,40 +122,27 @@ class CacheTransferManager: self.num_cpu_blocks = args.num_cpu_blocks cache_type = args.cache_dtype - for i in range(args.num_layers + self.num_extra_layers): - num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks + cache_shape = [ + args.num_gpu_blocks, + args.kv_num_head, + 
args.block_size, + args.head_dim, + ] - self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full( - shape=[ - num_gpu_blocks, - args.kv_num_head, - args.block_size, - args.head_dim, - ], - fill_value=0, - dtype=cache_type, - ) - self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"]) - self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full( - shape=[ - num_gpu_blocks, - args.kv_num_head, - args.block_size, - args.head_dim, - ], - fill_value=0, - dtype=cache_type, - ) - self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"]) + for i in range(args.num_hidden_layers + self.num_extra_layers): + num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else self.num_extra_layer_gpu_blocks + cache_shape[0] = num_gpu_blocks + key_name = f"key_caches_{i}_rank{rank}.device{device}" + value_name = f"value_caches_{i}_rank{rank}.device{device}" + key_cache = paddle.empty(shape=[], dtype=cache_type) + value_cache = paddle.empty(shape=[], dtype=cache_type) + key_cache = share_external_data(key_cache, key_name, cache_shape) + value_cache = share_external_data(value_cache, value_name, cache_shape) + self.gpu_cache_kvs[key_name] = key_cache + self.gpu_cache_kvs[value_name] = value_cache + self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name]) + self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name]) - set_data_ipc( - self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"], - f"key_caches_{i}_rank{rank}.device{device}", - ) - set_data_ipc( - self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"], - f"value_caches_{i}_rank{rank}.device{device}", - ) cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()]) logger.info(f"device :{self.device}") logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}") @@ -180,7 +151,7 @@ class CacheTransferManager: paddle.set_device("cpu") self.k_dst_ptrs = [] self.v_dst_ptrs = [] - for i in range(args.num_layers + self.num_extra_layers): + for i in range(args.num_hidden_layers + self.num_extra_layers): self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc( args.num_cpu_blocks * args.bytes_per_layer_per_block ) @@ -190,38 +161,6 @@ class CacheTransferManager: ) self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"]) - cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32) - self.cache_ready_signal = IPCSignal( - name="cache_ready_signal", - array=cache_ready_signal_data, - dtype=np.int32, - suffix=args.engine_pid, - create=False, - ) - self.cache_ready_signal.value[self.rank] = 1 - - paddle.set_device(f"gpu:{device}") - if args.enable_splitwise: - logger.debug("create cache messager...") - logger.info(f"{args}") - from fastdeploy.cache_manager.cache_messager import CacheMessager - - self.cache_messager = CacheMessager( - splitwise_role=args.splitwise_role, - transfer_protocol=args.protocol, - pod_ip=args.pod_ip, - engine_worker_queue_port=args.engine_worker_queue_port, - local_data_parallel_id=args.local_data_parallel_id, - gpu_cache_kvs=self.gpu_cache_kvs, - rank=self.rank, - nranks=args.mp_num, - num_layers=args.num_layers + self.num_extra_layers, - gpu_id=self.device, - rdma_port=args.rdma_port, - ) - logger.info("successfully create cache messager") - logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}") - cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) 
self.cache_task_broadcast_signal = IPCSignal( name="cache_task_broadcast_signal", diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 0ac34ad6a..e08e86eab 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -141,6 +141,76 @@ class PrefixCacheManager: filename = "cache_transfer_manager.py" py_path = os.path.join(current_dir_path, filename) + cache_messager_processes = [] + if self.splitwise_role != "mixed": + cache_messager_processes = self.launch_cache_messager( + cache_config, + tensor_parallel_size, + device_ids, + pod_ip, + engine_worker_queue_port, + pid_suffix, + ) + if cache_messager_processes is None: + raise RuntimeError("Launch cache messager failed") + return [] + + if ( + hasattr(cache_config.model_cfg, "num_key_value_heads") + and hasattr(cache_config.model_cfg, "num_key_value_heads") + and cache_config.model_cfg.num_key_value_heads is not None + and int(cache_config.model_cfg.num_key_value_heads) > 0 + ): + kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size + else: + kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size + + log_dir = envs.FD_LOG_DIR + cache_manager_processes = [] + for i in range(tensor_parallel_size): + launch_cmd = ( + f" {sys.executable} {py_path}" + + f" --device_id {int(device_ids[i])}" + + f" --rank {i}" + + f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}" + + f" --head_dim {cache_config.model_cfg.head_dim}" + + f" --kv_num_head {kv_num_head}" + + f" --mp_num {tensor_parallel_size}" + + f" --cache_dtype {cache_config.cache_dtype}" + + f" --cache_queue_port {cache_config.cache_queue_port}" + + f" --pod_ip {pod_ip}" + + f" --engine_worker_queue_port {engine_worker_queue_port}" + + f" --num_gpu_blocks {cache_config.total_block_num}" + + f" --num_cpu_blocks {cache_config.num_cpu_blocks}" + + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}" + + f" --block_size {cache_config.block_size}" + + f" --engine_pid {pid_suffix}" + + f" --local_data_parallel_id {self.local_data_parallel_id}" + + f" --speculative_config '{self.speculative_config.to_json_string()}'" + + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1" + ) + logger.info(f"Launch cache transfer manager, command:{launch_cmd}") + cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) + exit_code = cache_manager_processes[-1].poll() + if exit_code is None: + logger.info("Launch cache transfer manager successful") + else: + logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information") + + if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0: + logger.info("Enable hierarchical cache.") + self._enable_cpu_cache() + cache_manager_processes.extend(cache_messager_processes) + return cache_manager_processes + + def launch_cache_messager( + self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix + ): + """ + launch_cache_messager function used to initialize the cache messager. 
+ """ + current_dir_path = os.path.split(os.path.abspath(__file__))[0] + filename = "cache_messager.py" if ( hasattr(cache_config.model_cfg, "num_key_value_heads") and hasattr(cache_config.model_cfg, "num_key_value_heads") @@ -159,8 +229,10 @@ class PrefixCacheManager: suffix=pid_suffix, create=True, ) + + py_path = os.path.join(current_dir_path, filename) log_dir = envs.FD_LOG_DIR - cache_manager_processes = [] + cache_messager_processes = [] for i in range(tensor_parallel_size): launch_cmd = ( "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7" @@ -169,42 +241,34 @@ class PrefixCacheManager: + f" --device_id {int(device_ids[i])}" + f" --rank {i}" + f" --splitwise_role {self.splitwise_role}" - + f" --num_layers {cache_config.model_cfg.num_hidden_layers}" + + f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}" + f" --head_dim {cache_config.model_cfg.head_dim}" + f" --kv_num_head {kv_num_head}" + f" --mp_num {tensor_parallel_size}" + f" --cache_dtype {cache_config.cache_dtype}" - + f" --cache_queue_port {cache_config.cache_queue_port}" - + f" --enable_splitwise {int(self.enable_splitwise)}" + f" --pod_ip {pod_ip}" + f" --engine_worker_queue_port {engine_worker_queue_port}" + f" --num_gpu_blocks {cache_config.total_block_num}" - + f" --num_cpu_blocks {cache_config.num_cpu_blocks}" - + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}" + f" --block_size {cache_config.block_size}" - + f" --engine_pid {pid_suffix}" + f" --protocol {cache_config.cache_transfer_protocol}" + f" --local_data_parallel_id {self.local_data_parallel_id}" + + f" --engine_pid {pid_suffix}" + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + f" --speculative_config '{self.speculative_config.to_json_string()}'" - + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1" + + f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1" ) - logger.info(f"Launch cache transfer manager, command:{launch_cmd}") - cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) - # 等待cache初始化完毕 - logger.info("Waiting for cache transfer manager ready...") + logger.info(f"Launch cache messager, command:{launch_cmd}") + cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) + logger.info("Waiting for cache ready...") while np.sum(self.cache_ready_signal.value) != tensor_parallel_size: time.sleep(1) - exit_code = cache_manager_processes[-1].poll() + exit_code = cache_messager_processes[-1].poll() if exit_code is None: - logger.info("Launch cache transfer manager successful") + logger.info("Launch cache messager successful") else: - logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information") - - if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0: - logger.info("Enable hierarchical cache.") - self._enable_cpu_cache() - return cache_manager_processes + logger.info("Launch cache messager failed, see launch_cache_messager.log for more information") + cache_messager_processes = None + return cache_messager_processes def update_cache_config(self, cache_config): """ diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 6ed550509..dccd53ced 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -775,10 +775,6 @@ class LLMEngine: """ Insert tasks to engine. 
""" - for task in tasks: - start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) - if task.sampling_params.bad_words is not None: - task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer) # TODO 返回至 scheduler if allocated: current_tasks = [] @@ -805,6 +801,11 @@ class LLMEngine: self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) return True + for task in tasks: + start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) + if task.sampling_params.bad_words is not None: + task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer) + self.resource_manager.check_and_free_block_tables() if not isinstance(tasks, list): @@ -846,11 +847,10 @@ class LLMEngine: llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") for task in tasks: task.inference_start_time = time.time() - if not is_prefill: - if not self.cfg.enable_mm: - self.update_requests_chunk_size(tasks) - else: - self.update_mm_requests_chunk_size(tasks) + if not self.cfg.enable_mm: + self.update_requests_chunk_size(tasks) + else: + self.update_mm_requests_chunk_size(tasks) self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) if is_prefill and self.cfg.scheduler_config.name != "splitwise": self.engine_worker_queue.available_prefill_instances.put(1) @@ -992,14 +992,17 @@ class LLMEngine: self.running = False if hasattr(self, "cache_manager_processes"): - self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() - self.resource_manager.cache_manager.cache_ready_signal.clear() for p in self.cache_manager_processes: llm_logger.info(f"Killing cache manager process {p.pid}") try: os.killpg(p.pid, signal.SIGTERM) except Exception as e: print(f"Error extracting file: {e}") + if hasattr(self.resource_manager.cache_manager, "cache_ready_signal"): + self.resource_manager.cache_manager.cache_ready_signal.clear() + self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() + if hasattr(self, "zmq_server") and self.zmq_server is not None: + self.zmq_server.close() self.worker_ready_signal.clear() self.exist_task_signal.clear() self.exist_swapped_task_signal.clear() @@ -1024,6 +1027,7 @@ class LLMEngine: if hasattr(self, "dp_processed"): for p in self.dp_processed: p.join() + self.engine_worker_queue_server.cleanup() def _setting_environ_variables(self): """ diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 39f0fce42..e9c1e63a4 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -37,6 +37,7 @@ from fastdeploy.model_executor.ops.gpu import ( eagle_get_self_hidden_states, mtp_save_first_token, mtp_step_paddle, + set_data_ipc, share_external_data, ) from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding @@ -141,9 +142,7 @@ class MTPProposer(Proposer): kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type ) - if not self.parallel_config.do_profile and ( - self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed" - ): + if not self.parallel_config.do_profile and self.parallel_config.splitwise_role != "mixed": cache_kvs_list = [] for i in range( self.num_main_model_layers, @@ -160,7 +159,10 @@ class MTPProposer(Proposer): self.model_inputs["caches"] = cache_kvs_list else: - for i in range(self.model_config.num_hidden_layers): + for i in range( + self.num_main_model_layers, + self.num_main_model_layers + 
self.model_config.num_hidden_layers, + ): self.cache_kvs[f"key_caches_{i}"] = paddle.full( shape=kv_cache_shape, fill_value=0, @@ -171,6 +173,15 @@ class MTPProposer(Proposer): fill_value=0, dtype=cache_type, ) + if self.cache_config.enable_prefix_caching: + set_data_ipc( + self.cache_kvs[f"key_caches_{i}"], + f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}", + ) + set_data_ipc( + self.cache_kvs[f"value_caches_{i}"], + f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}", + ) self.model_inputs["caches"] = list(self.cache_kvs.values()) for value in self.cache_kvs.values(): del value @@ -235,7 +246,7 @@ class MTPProposer(Proposer): self.main_model_num_gpu_blocks = num_gpu_blocks self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) - if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): + if self.parallel_config.splitwise_role == "mixed": self.initialize_kv_cache() # Reset free list diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 4b67b595e..59f44edb5 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.ops.gpu import ( recover_decode_task, + set_data_ipc, set_value_by_flags_and_idx, share_external_data, ) @@ -904,7 +905,7 @@ class GPUModelRunner(ModelRunnerBase): ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size - if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): + if not profile and self.parallel_config.splitwise_role != "mixed": cache_kvs_list = [] for i in range(self.model_config.num_hidden_layers): key_cache = paddle.empty(shape=[], dtype=cache_type) @@ -930,6 +931,15 @@ class GPUModelRunner(ModelRunnerBase): fill_value=0, dtype=cache_type, ) + if self.cache_config.enable_prefix_caching: + set_data_ipc( + cache_kvs[f"key_caches_{i}"], + f"key_caches_{i}_rank{local_rank}.device{self.device_id}", + ) + set_data_ipc( + cache_kvs[f"value_caches_{i}"], + f"value_caches_{i}_rank{local_rank}.device{self.device_id}", + ) self.share_inputs["caches"] = list(cache_kvs.values()) for value in cache_kvs.values(): del value @@ -1138,6 +1148,8 @@ class GPUModelRunner(ModelRunnerBase): if task.chunk_idx > len(task.prefill_chunk_info): continue self.restore_chunked_prefill_request[task.request_id] = task + if len(self.restore_chunked_prefill_request) > 0: + self.share_inputs["not_need_stop"][0] = True for id, task in list(self.restore_chunked_prefill_request.items()): idx = task.idx @@ -1182,7 +1194,7 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size self.share_inputs["prompt_lens"][idx : idx + 1] += token_chunk_size self.share_inputs["step_idx"][idx : idx + 1] = 0 - + self.share_inputs["stop_flags"][idx : idx + 1] = False if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled(): self.proposer.update_task_chunk_prefill(task) task.chunk_idx += 1 @@ -1507,12 +1519,12 @@ class GPUModelRunner(ModelRunnerBase): hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads # NOTE(liuzichang): Implement multi-layer MTP architecture in the future - num_layers = ( + num_hidden_layers = ( self.model_config.num_hidden_layers + 
self.speculative_config.num_gpu_block_expand_ratio if self.speculative_method in ["mtp"] else self.model_config.num_hidden_layers ) - required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v + required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_hidden_layers # k + v return required_memory def not_need_stop(self) -> bool: diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index eace66487..1573714b5 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -408,7 +408,7 @@ class PaddleDisWorkerProc: logger.info(f"------- num_blocks_global: {num_blocks_local} --------") # wait engine launch cache_manager - if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed": + if self.parallel_config.splitwise_role != "mixed": launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) self.launched_cache_manager_signal = IPCSignal( name="launched_cache_manager_signal",