[Feature] support prefix cache in DP (#4359)

* [Feature] support prefix cache in DP

* fix

* Update common_engine.py

* Update common_engine.py

* Update common_engine.py

* Update common_engine.py

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
ltd0924
2025-10-11 10:12:12 +08:00
committed by GitHub
parent 368049673b
commit 3f535b45a2
4 changed files with 46 additions and 11 deletions

View File

@@ -198,7 +198,7 @@ class EngineArgs:
The amount of CPU memory to offload to.
"""
cache_queue_port: int = 8003
cache_queue_port: str = "8003"
"""
Port for cache queue.
"""
@@ -741,7 +741,7 @@ class EngineArgs:
cache_group.add_argument(
"--cache-queue-port",
type=int,
type=lambda s: [int(item.strip()) for item in s.split(",")] if s else None,
default=EngineArgs.cache_queue_port,
help="port for cache queue",
)

View File

@@ -68,6 +68,12 @@ class EngineService:
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
if isinstance(self.cfg.cache_config.cache_queue_port, str):
self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port.split(",")
if isinstance(self.cfg.cache_config.cache_queue_port, list):
self.cfg.cache_config.cache_queue_port = int(
self.cfg.cache_config.cache_queue_port[self.cfg.parallel_config.local_data_parallel_id]
)
if self.cfg.parallel_config.enable_expert_parallel:
self.llm_logger = get_logger(
@@ -251,11 +257,7 @@ class EngineService:
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
if (
self.cfg.cache_config.enable_prefix_caching
or self.cfg.scheduler_config.splitwise_role != "mixed"
and self.cfg.parallel_config.local_data_parallel_id == 0
):
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
self.cache_task_queue = EngineCacheQueue(
address=(
self.cfg.master_ip,

View File

@@ -57,6 +57,11 @@ class ExpertService:
llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
self.cfg.disaggregate_info = None
if self.cfg.cache_config.num_gpu_blocks_override is None:
self.do_profile = True
else:
self.do_profile = False
if cfg.scheduler_config.splitwise_role != "mixed":
if len(self.cfg.cache_config.pd_comm_port) == 1:
self.cfg.cache_config.pd_comm_port[0] = (
@@ -97,9 +102,29 @@ class ExpertService:
ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
llm_logger.info(f"start expert service {local_data_parallel_id}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.cfg.splitwise_role != "mixed" or self.cfg.cache_config.enable_prefix_caching:
if self.do_profile:
get_profile_block_num = np.zeros([1], dtype=np.int32)
while True:
try:
self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
suffix=int(self.cfg.engine_worker_queue_port[0]),
create=False,
)
break
except:
time.sleep(1)
self.reset_kvcache_blocks()
ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id]
self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix_cache)
self.cache_manager_processes = self.engine.start_cache_service(
self.cfg.local_device_ids, ipc_signal_suffix_cache
)
if self.cfg.splitwise_role != "mixed":
self.engine.split_mode_get_tasks()
if self.cfg.scheduler_config.name == "splitwise":
self.cfg.init_cache_info()
@@ -135,6 +160,14 @@ class ExpertService:
)
return True
def reset_kvcache_blocks(self):
self.do_profile = 0
while self.get_profile_block_num_signal.value[0] == 0:
time.sleep(1)
num_gpu_blocks = self.get_profile_block_num_signal.value[0]
self.cfg.cache_config.reset(num_gpu_blocks)
self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
def _exit_sub_services(self):
"""
exit sub services

View File

@@ -1185,7 +1185,7 @@ class GPUModelRunner(ModelRunnerBase):
if not create_cache_tensor:
logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
while cache_ready_signal.value[self.local_rank] != 1:
while cache_ready_signal.value[local_rank] != 1:
time.sleep(1)
logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
@@ -1236,7 +1236,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["caches"] = cache_kvs_list
if not profile and create_cache_tensor:
cache_ready_signal.value[self.local_rank] = 1
cache_ready_signal.value[local_rank] = 1
logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")
paddle.device.cuda.empty_cache()