[Feature] Support using prefix-caching + cudagraph for inference (#2924)

* fix the bug in cudagraph + prefix-caching; some issues remain with profiling

Change-Id: Ibf2ba3f2e3b08641d03f4b1391d7c862c3efa397

* add a signal to make sure the cache manager has launched (a minimal sketch of this handshake is included below the commit metadata)

* fix the condition check

* remove useless control logic

* update control stream

* update

* fix xpu

* change the do_profile flag

* update

* add new threads to init cache_manager

---------

Co-authored-by: RAM <gstian5555@outlook.com>
Authored by Zero Rains on 2025-07-22 15:59:45 +08:00; committed by GitHub
parent 48e6a0ca26
commit 89a485b69f
11 changed files with 63 additions and 65 deletions
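
The core of the change is a readiness handshake: the worker must not allocate its KV cache (and must not capture CUDA graphs) until the engine has actually launched the cache manager. Below is a minimal, generic sketch of that handshake using multiprocessing.shared_memory rather than FastDeploy's IPCSignal; every name in it (SIGNAL_NAME, engine_side, worker_side) is illustrative and not part of the codebase.

# Minimal sketch (not FastDeploy code): a readiness-signal handshake between an
# engine process that launches the cache manager and a worker process that must
# not build its KV cache until the cache manager exists.
import time
import numpy as np
from multiprocessing import shared_memory

SIGNAL_NAME = "launched_cache_manager_demo"  # illustrative name

def engine_side():
    # Engine creates the shared flag, launches the cache manager, then flips it to 1.
    shm = shared_memory.SharedMemory(name=SIGNAL_NAME, create=True, size=4)
    flag = np.ndarray((1,), dtype=np.int32, buffer=shm.buf)
    flag[0] = 0
    # ... launch the cache manager process here ...
    flag[0] = 1  # announce readiness
    return shm   # keep a reference so the segment stays alive

def worker_side():
    # Worker attaches to the same flag and polls until the engine sets it,
    # mirroring the wait loop added in the diff below.
    shm = shared_memory.SharedMemory(name=SIGNAL_NAME, create=False)
    flag = np.ndarray((1,), dtype=np.int32, buffer=shm.buf)
    while flag[0] <= 0:
        time.sleep(0.01)
    # safe to initialize the KV cache and capture CUDA graphs now
    shm.close()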


@@ -347,7 +347,7 @@ class PaddleDisWorkerProc:
         self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished()

-    def determine_num_available_blocks(self) -> None:
+    def initialize_kv_cache(self) -> None:
         """Profiles the peak memory usage of the model to determine how many
         KV blocks may be allocated without OOMs.
@@ -400,8 +400,25 @@ class PaddleDisWorkerProc:
             self.get_profile_block_num_signal.value[0] = num_blocks_local
         else:
             num_blocks_local = self.fd_config.parallel_config.total_block_num
-        # 4. Updata share inputs
-        self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)
         logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
+        # wait engine launch cache_manager
+        if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
+            self.launched_cache_manager_signal = IPCSignal(
+                name="launched_cache_manager_signal",
+                array=launched_cache_manager_signal_data,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_pid,
+                create=False,
+            )
+            while np.any(self.launched_cache_manager_signal.value[0] <= 0):
+                time.sleep(0.01)
+        # 4. init kv_cache with accurate num_blocks
+        self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
+
+    def graph_optimize_and_warm_up_model(self) -> None:
+        self.worker.graph_optimize_and_warm_up_model()

     def init_device(self) -> None:
         """Initialize device and Construct model runner"""
@@ -714,8 +731,8 @@ def run_worker_proc() -> None:
     # Load model
     worker_proc.load_model()

-    logger.info("determine_num_available_blocks")
-    worker_proc.determine_num_available_blocks()
+    # Initialize KV Cache
+    worker_proc.initialize_kv_cache()

     # Trigger CUDAGraph capture
     worker_proc.worker.graph_optimize_and_warm_up_model()
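
For readability, here is the resulting worker start-up ordering in simplified form; the call names are taken from the diff, everything between them is omitted, and the "why" comment on the last line is an inference from the commit's goal rather than something stated in the source.

# Simplified start-up ordering after this change (sketch, not the full run_worker_proc body)
worker_proc.load_model()                                # load weights
worker_proc.initialize_kv_cache()                       # profile memory, wait for the cache manager, allocate KV blocks
worker_proc.worker.graph_optimize_and_warm_up_model()   # capture CUDA graphs only after the KV cache exists

Keeping graph capture last presumably lets the captured graphs see the final KV-cache layout, which is what allows prefix caching and CUDA Graph to be enabled together.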