Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] Optimize prefix cache (#3208)
* [LLM] support ep
* Update worker_process.py
* Update expert_service.py
* Update worker_process.py
* format files
* optimize prefix cache
* optimize prefix cache
* optimize prefix cache
* pre commit format
* pre commit format
* pre commit format
* Update cache_messager.py
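At a high level, the diff below reworks PrefixCacheManager.launch_cache_manager so that, for split-wise (non-"mixed") deployments, it first launches the cache messager processes and then returns them together with the per-rank cache transfer manager processes. The following is a minimal illustrative sketch of that control flow, not the actual implementation; the signature is simplified and the helper _launch_transfer_manager is hypothetical, standing in for the per-rank cache_transfer_manager.py launch shown in the diff.

    # Illustrative sketch only; _launch_transfer_manager is a hypothetical stand-in
    # for the per-rank subprocess launch that appears in the diff below.
    def launch_cache_manager(self, cache_config, tensor_parallel_size, device_ids,
                             pod_ip, engine_worker_queue_port, pid_suffix):
        cache_messager_processes = []
        # Only split-wise (disaggregated prefill/decode) roles need the messager.
        if self.splitwise_role != "mixed":
            cache_messager_processes = self.launch_cache_messager(
                cache_config, tensor_parallel_size, device_ids,
                pod_ip, engine_worker_queue_port, pid_suffix,
            )
            if cache_messager_processes is None:
                raise RuntimeError("Launch cache messager failed")

        # One cache_transfer_manager.py subprocess per tensor-parallel rank.
        cache_manager_processes = [
            self._launch_transfer_manager(rank, cache_config)
            for rank in range(tensor_parallel_size)
        ]

        # Return messager and transfer-manager processes together so the caller
        # can monitor and terminate them as one group.
        cache_manager_processes.extend(cache_messager_processes)
        return cache_manager_processes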
@@ -141,6 +141,76 @@ class PrefixCacheManager:
        filename = "cache_transfer_manager.py"
        py_path = os.path.join(current_dir_path, filename)

        cache_messager_processes = []
        if self.splitwise_role != "mixed":
            cache_messager_processes = self.launch_cache_messager(
                cache_config,
                tensor_parallel_size,
                device_ids,
                pod_ip,
                engine_worker_queue_port,
                pid_suffix,
            )
            if cache_messager_processes is None:
                raise RuntimeError("Launch cache messager failed")
                return []

        if (
            hasattr(cache_config.model_cfg, "num_key_value_heads")
            and hasattr(cache_config.model_cfg, "num_key_value_heads")
            and cache_config.model_cfg.num_key_value_heads is not None
            and int(cache_config.model_cfg.num_key_value_heads) > 0
        ):
            kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
        else:
            kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size

        log_dir = envs.FD_LOG_DIR
        cache_manager_processes = []
        for i in range(tensor_parallel_size):
            launch_cmd = (
                f" {sys.executable} {py_path}"
                + f" --device_id {int(device_ids[i])}"
                + f" --rank {i}"
                + f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
                + f" --head_dim {cache_config.model_cfg.head_dim}"
                + f" --kv_num_head {kv_num_head}"
                + f" --mp_num {tensor_parallel_size}"
                + f" --cache_dtype {cache_config.cache_dtype}"
                + f" --cache_queue_port {cache_config.cache_queue_port}"
                + f" --pod_ip {pod_ip}"
                + f" --engine_worker_queue_port {engine_worker_queue_port}"
                + f" --num_gpu_blocks {cache_config.total_block_num}"
                + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
                + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
                + f" --block_size {cache_config.block_size}"
                + f" --engine_pid {pid_suffix}"
                + f" --local_data_parallel_id {self.local_data_parallel_id}"
                + f" --speculative_config '{self.speculative_config.to_json_string()}'"
                + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
            )
            logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
            cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
        exit_code = cache_manager_processes[-1].poll()
        if exit_code is None:
            logger.info("Launch cache transfer manager successful")
        else:
            logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")

        if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
            logger.info("Enable hierarchical cache.")
            self._enable_cpu_cache()
        cache_manager_processes.extend(cache_messager_processes)
        return cache_manager_processes

    def launch_cache_messager(
        self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
    ):
        """
        launch_cache_messager function used to initialize the cache messager.
        """
        current_dir_path = os.path.split(os.path.abspath(__file__))[0]
        filename = "cache_messager.py"
        if (
            hasattr(cache_config.model_cfg, "num_key_value_heads")
            and hasattr(cache_config.model_cfg, "num_key_value_heads")
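The hunk above launches one cache_transfer_manager.py subprocess per tensor-parallel rank: it assembles a shell command, starts it with subprocess.Popen(..., shell=True, preexec_fn=os.setsid), and uses poll() to tell a still-running child from one that has already exited. Below is a self-contained sketch of that launch-and-poll pattern; the script path and flags are placeholders, not the real cache_transfer_manager.py command line.

    import os
    import subprocess
    import sys

    # Standalone sketch of the launch-and-poll pattern; py_path and the flags
    # are placeholders rather than the real cache_transfer_manager.py CLI.
    def launch_rank_process(py_path: str, device_id: int, rank: int, log_dir: str) -> subprocess.Popen:
        launch_cmd = (
            f"{sys.executable} {py_path}"
            + f" --device_id {device_id}"
            + f" --rank {rank}"
            + f" >{log_dir}/launch_cache_manager_{device_id}.log 2>&1"
        )
        # shell=True lets the command string carry the output redirection;
        # preexec_fn=os.setsid puts the child in its own process group so the
        # whole group can be signalled on shutdown.
        proc = subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)
        # poll() returns None while the process is alive, or its exit code otherwise.
        if proc.poll() is None:
            print(f"rank {rank}: launch successful (still running)")
        else:
            print(f"rank {rank}: launch failed with exit code {proc.poll()}")
        return proc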
@@ -159,8 +229,10 @@ class PrefixCacheManager:
            suffix=pid_suffix,
            create=True,
        )

        py_path = os.path.join(current_dir_path, filename)
        log_dir = envs.FD_LOG_DIR
        cache_manager_processes = []
        cache_messager_processes = []
        for i in range(tensor_parallel_size):
            launch_cmd = (
                "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
@@ -169,42 +241,34 @@ class PrefixCacheManager:
                + f" --device_id {int(device_ids[i])}"
                + f" --rank {i}"
                + f" --splitwise_role {self.splitwise_role}"
                + f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
                + f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
                + f" --head_dim {cache_config.model_cfg.head_dim}"
                + f" --kv_num_head {kv_num_head}"
                + f" --mp_num {tensor_parallel_size}"
                + f" --cache_dtype {cache_config.cache_dtype}"
                + f" --cache_queue_port {cache_config.cache_queue_port}"
                + f" --enable_splitwise {int(self.enable_splitwise)}"
                + f" --pod_ip {pod_ip}"
                + f" --engine_worker_queue_port {engine_worker_queue_port}"
                + f" --num_gpu_blocks {cache_config.total_block_num}"
                + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
                + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
                + f" --block_size {cache_config.block_size}"
                + f" --engine_pid {pid_suffix}"
                + f" --protocol {cache_config.cache_transfer_protocol}"
                + f" --local_data_parallel_id {self.local_data_parallel_id}"
                + f" --engine_pid {pid_suffix}"
                + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
                + f" --speculative_config '{self.speculative_config.to_json_string()}'"
                + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
                + f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
            )
            logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
            cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
        # Wait until the cache finishes initializing
        logger.info("Waiting for cache transfer manager ready...")
            logger.info(f"Launch cache messager, command:{launch_cmd}")
            cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
        logger.info("Waiting for cache ready...")
        while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
            time.sleep(1)
        exit_code = cache_manager_processes[-1].poll()
        exit_code = cache_messager_processes[-1].poll()
        if exit_code is None:
            logger.info("Launch cache transfer manager successful")
            logger.info("Launch cache messager successful")
        else:
            logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")

        if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
            logger.info("Enable hierarchical cache.")
            self._enable_cpu_cache()
        return cache_manager_processes
            logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
            cache_messager_processes = None
        return cache_messager_processes

    def update_cache_config(self, cache_config):
        """
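In the new launch_cache_messager, the method blocks until every rank has reported its cache as ready: each worker sets its slot in the shared cache_ready_signal array to 1, the loop exits once the sum equals tensor_parallel_size, and only then is the last launched process polled for an early exit. Below is a minimal sketch of that readiness wait, using a plain NumPy array in place of the real IPCSignal shared-memory wrapper and adding a timeout that the original loop does not have.

    import time
    import numpy as np

    # Sketch only: a plain in-process array stands in for the IPCSignal
    # shared-memory wrapper; each rank sets its slot to 1 when its cache is ready.
    def wait_for_cache_ready(cache_ready_signal: np.ndarray,
                             tensor_parallel_size: int,
                             timeout_s: float = 300.0) -> bool:
        deadline = time.time() + timeout_s
        while np.sum(cache_ready_signal) != tensor_parallel_size:
            if time.time() > deadline:
                # The original loop waits indefinitely; a timeout lets the caller
                # treat a stuck launch as a failure instead of hanging.
                return False
            time.sleep(1)
        return True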