[Feature] Support using prefix-caching + cudagraph for inference (#2924)

* fix the bug in cudagraph + prefix-caching, but there is still a bug with profiling

Change-Id: Ibf2ba3f2e3b08641d03f4b1391d7c862c3efa397

* add a signal to make sure the cache manager has launched

* fix the judgment condition

* remove useless control

* update control stream

* update

* fix xpu

* change the do_profile flag

* update

* add a new thread to init the cache_manager

---------

Co-authored-by: RAM <gstian5555@outlook.com>
Zero Rains authored on 2025-07-22 15:59:45 +08:00, committed by GitHub
parent 48e6a0ca26, commit 89a485b69f
11 changed files with 63 additions and 65 deletions

View File

@@ -863,8 +863,6 @@ class EngineArgs:
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
-        assert not (self.use_cudagraph and self.enable_prefix_caching), "Prefix caching cannot be used with CUDA graph"
         assert not (
             self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
         ), "enable_custom_all_reduce must be used with tensor_parallel_size>1"

View File

@@ -183,6 +183,7 @@ class LLMEngine:
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=self.ipc_signal_suffix,
         )
+        self.launched_cache_manager_signal.value[0] = 1

         self.worker_proc = self._start_worker_service()
         console_logger.info("Waitting worker processes ready...")

@@ -217,9 +218,6 @@ class LLMEngine:
         # Start TokenProcessor thread
         self.token_processor.run()

-        if self.do_profile:
-            self._stop_profile()

         if self.cfg.splitwise_role != "mixed":
             # single-machine logic
             self.engine_worker_queue.available_prefill_instances.put(1)

@@ -849,6 +847,17 @@ class LLMEngine:
             create=True,
         )
+        # launched_cache_manager_signal is used to detect whether the engine has launched the cache_manager
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
+            self.launched_cache_manager_signal = IPCSignal(
+                name="launched_cache_manager_signal",
+                array=launched_cache_manager_signal_data,
+                dtype=np.int32,
+                suffix=self.ipc_signal_suffix,
+                create=True,
+            )
         # worker_live_signal is used by the engine to detect whether each worker process is alive and to record the time of each step
         worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(

@@ -1133,6 +1142,7 @@ class LLMEngine:
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=self.ipc_signal_suffix,
         )
+        self.launched_cache_manager_signal.value[0] = 1

     def check_health(self, time_interval_threashold=30):
         """

@@ -1171,6 +1181,10 @@ class LLMEngine:
         self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
         self.checking_worker_status_thread.start()
+        checking_worker_init_kv_cache_status_thread = None
+        if self.do_profile:
+            checking_worker_init_kv_cache_status_thread = threading.Thread(target=self._stop_profile, daemon=True)
+            checking_worker_init_kv_cache_status_thread.start()

         # display weight loadding progress
         with tqdm(total=100, desc="Loading Weights") as pbar:

@@ -1201,6 +1215,8 @@ class LLMEngine:
                 self.worker_init_status["finished"] = True
             try:
                 self.checking_worker_status_thread.join(timeout=1)
+                if checking_worker_init_kv_cache_status_thread is not None:
+                    checking_worker_init_kv_cache_status_thread.join(timeout=1)
             except Exception:
                 pass
             return True
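
Taken together, the engine-side changes implement a small IPC handshake: the engine allocates the launched_cache_manager_signal shared array before the workers need it, launches the cache-manager processes, and only then flips the signal to 1; profile teardown (_stop_profile) is also moved off the startup path into a daemon thread. A condensed sketch of that ordering (hedged; the cache-manager launch call is an illustrative placeholder):

    # Hedged sketch of the engine-side handshake; _launch_cache_manager() is a placeholder name.
    launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
    self.launched_cache_manager_signal = IPCSignal(
        name="launched_cache_manager_signal",
        array=launched_cache_manager_signal_data,
        dtype=np.int32,
        suffix=self.ipc_signal_suffix,
        create=True,                                 # the engine owns the shared buffer
    )
    self._launch_cache_manager()                     # illustrative: start the cache-manager processes
    self.launched_cache_manager_signal.value[0] = 1  # tell the workers the cache manager is ready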

View File

@@ -151,8 +151,6 @@ class GCUModelRunner(ModelRunnerBase):
         """
         Process inputs for prefill tasks and insert it to share_inputs buffer
         """
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
             os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"

@@ -561,7 +559,7 @@ class GCUModelRunner(ModelRunnerBase):
             self.initialize_kv_cache()
             self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """

@@ -582,7 +580,7 @@ class GCUModelRunner(ModelRunnerBase):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
         # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not self.parallel_config.do_profile and (
+        if not profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
             raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")

@@ -1012,7 +1010,7 @@ class GCUModelRunner(ModelRunnerBase):
         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         self.num_gcu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         # 1. Profile with multimodal encoder & encoder cache

@@ -1038,8 +1036,7 @@ class GCUModelRunner(ModelRunnerBase):
         self.num_gcu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache()

         # Reset free list
         free_list = list(

@@ -1057,8 +1054,6 @@ class GCUModelRunner(ModelRunnerBase):
             }
         )
-        self.parallel_config.do_profile = False

         if self.speculative_method in ["mtp"]:
             self.proposer.update_block_num(num_gpu_blocks)

View File

@@ -98,9 +98,9 @@ class GcuWorker(WorkerBase):
         """ """
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """ """
-        pass
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,

@@ -134,7 +134,3 @@ class GcuWorker(WorkerBase):
     def cal_theortical_kvcache(self) -> int:
         """ """
         return self.model_runner.cal_theortical_kvcache()
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """ """
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

View File

@@ -193,9 +193,6 @@ class GPUModelRunner(ModelRunnerBase):
         Process inputs for prefill tasks and insert it to share_inputs buffer
         TODO(gongshaotian): Refactor this func
         """
-        # NOTE(luotingdan): Lazy initialize kv cache
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         # NOTE(luotingdan): Set environment variable of prefill node
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":

@@ -700,7 +697,7 @@ class GPUModelRunner(ModelRunnerBase):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """

@@ -721,7 +718,7 @@ class GPUModelRunner(ModelRunnerBase):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not self.parallel_config.do_profile and (
+        if not profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
             cache_kvs_list = []

@@ -739,7 +736,6 @@ class GPUModelRunner(ModelRunnerBase):
         else:
             for i in range(self.model_config.num_hidden_layers):
-
                 cache_kvs[f"key_caches_{i}"] = paddle.full(
                     shape=kv_cache_shape,
                     fill_value=0,

@@ -1218,7 +1214,7 @@ class GPUModelRunner(ModelRunnerBase):
         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         # TODO(gongshaotian): Optimize the management logic of kvcache
         self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         # 1. Profile with multimodal encoder & encoder cache

@@ -1243,8 +1239,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.num_gpu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache()

         # Reset free list
         free_list = list(

@@ -1262,8 +1257,6 @@ class GPUModelRunner(ModelRunnerBase):
             }
         )
-        self.parallel_config.do_profile = False

         if self.speculative_method in ["mtp"]:
             self.proposer.update_block_num(num_gpu_blocks)
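
The runner now follows a two-phase KV-cache lifecycle: a temporary cache is allocated with initialize_kv_cache(profile=True) for the memory-profiling run, and the cache is re-initialized unconditionally once the accurate block budget is known, instead of skipping re-initialization when prefix caching or splitwise mode is enabled. A hedged sketch of that sequence (method and attribute names are taken from the hunks above; everything else is elided):

    # Hedged sketch of the two-phase KV-cache lifecycle in GPUModelRunner.
    self.num_gpu_blocks = self.parallel_config.total_block_num
    self.initialize_kv_cache(profile=True)   # throwaway cache for the profile run
    # ... profile run executes here and peak memory is measured ...
    self.num_gpu_blocks = num_gpu_blocks     # accurate budget derived from profiling
    self.initialize_kv_cache()               # re-init even with prefix caching / splitwise enabled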

View File

@@ -165,9 +165,10 @@ class GpuWorker(WorkerBase):
         """Get current model"""
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
-        """Initizlize the KV Cache"""
-        pass
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
+        """Initizlize the KV Cache with accurate num_gpu_blocks"""
+        # accurate cache size
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,

@@ -198,7 +199,3 @@ class GpuWorker(WorkerBase):
     def cal_theortical_kvcache(self) -> int:
         """Calculate the block memory required"""
         return self.model_runner.cal_theortical_kvcache()
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """Reinitialize the kv cache using the parameters from the profile"""
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

View File

@@ -141,9 +141,6 @@ class IluvatarModelRunner(ModelRunnerBase):
         Process inputs for prefill tasks and insert it to share_inputs buffer
         TODO(gongshaotian): Refactor this func
         """
-        # NOTE(luotingdan): Lazy initialize kv cache
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         # NOTE(luotingdan): Set environment variable of prefill node
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":

@@ -552,7 +549,7 @@ class IluvatarModelRunner(ModelRunnerBase):
         if self.forward_meta is not None:
             self.forward_meta.clear_caches()

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """

@@ -992,7 +989,7 @@ class IluvatarModelRunner(ModelRunnerBase):
         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         # TODO(gongshaotian): Optimize the management logic of kvcache
         self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         # 1. Profile with multimodal encoder & encoder cache

@@ -1016,8 +1013,7 @@ class IluvatarModelRunner(ModelRunnerBase):
         self.num_gpu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache()

         # Reset free list
         free_list = list(

@@ -1035,8 +1031,6 @@ class IluvatarModelRunner(ModelRunnerBase):
             }
         )
-        self.parallel_config.do_profile = False

     def cal_theortical_kvcache(self):
         """
         Calculate the total block memory required at the model level

View File

@@ -99,9 +99,9 @@ class IluvatarWorker(WorkerBase):
         """ """
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """ """
-        pass
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,

@@ -135,7 +135,3 @@ class IluvatarWorker(WorkerBase):
     def cal_theortical_kvcache(self) -> int:
         """ """
         return self.model_runner.cal_theortical_kvcache()
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """ """
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

View File

@@ -64,7 +64,7 @@ class WorkerBase(ABC):
         raise NotImplementedError

     @abstractmethod
-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """Initizlize the KV Cache with the given size in blocks."""
         raise NotImplementedError
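
Every backend worker touched by this commit satisfies the narrowed contract the same way: it forwards the accurate block count to its model runner. A minimal sketch of a conforming implementation (the class name is illustrative; the runner call is the one used by the GPU, GCU, Iluvatar, and XPU workers elsewhere in this commit):

    # Hedged sketch of a worker implementing the new one-argument contract; names are illustrative.
    class MyAcceleratorWorker(WorkerBase):
        def initialize_cache(self, num_gpu_blocks: int) -> None:
            """Resize the shared-input block tables once the accurate block count is known."""
            self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)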

View File

@@ -347,7 +347,7 @@ class PaddleDisWorkerProc:
             self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished()

-    def determine_num_available_blocks(self) -> None:
+    def initialize_kv_cache(self) -> None:
         """Profiles the peak memory usage of the model to determine how many
         KV blocks may be allocated without OOMs.

@@ -400,8 +400,25 @@ class PaddleDisWorkerProc:
             self.get_profile_block_num_signal.value[0] = num_blocks_local
         else:
             num_blocks_local = self.fd_config.parallel_config.total_block_num
-        # 4. Updata share inputs
-        self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)
+        logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
+        # wait engine launch cache_manager
+        if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
+            self.launched_cache_manager_signal = IPCSignal(
+                name="launched_cache_manager_signal",
+                array=launched_cache_manager_signal_data,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_pid,
+                create=False,
+            )
+            while np.any(self.launched_cache_manager_signal.value[0] <= 0):
+                time.sleep(0.01)
+        # 4. init kv_cache with accurate num_blocks
+        self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
+
+    def graph_optimize_and_warm_up_model(self) -> None:
+        self.worker.graph_optimize_and_warm_up_model()

     def init_device(self) -> None:
         """Initialize device and Construct model runner"""

@@ -714,8 +731,8 @@ def run_worker_proc() -> None:
     # Load model
     worker_proc.load_model()

-    logger.info("determine_num_available_blocks")
-    worker_proc.determine_num_available_blocks()
+    # Initialize KV Cache
+    worker_proc.initialize_kv_cache()

     # Trigger CUDAGraph capture
     worker_proc.worker.graph_optimize_and_warm_up_model()
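
On the worker side the startup order becomes: load the model, profile and publish the measured block count, wait for the engine's launched_cache_manager_signal, allocate the final KV cache via initialize_cache, and only then capture CUDA graphs, so the graphs are presumably recorded against the final cache rather than the temporary profiling cache. A condensed sketch of that sequence (hedged; intermediate steps are elided):

    # Hedged sketch of the worker startup order established by this commit.
    worker_proc.load_model()
    worker_proc.initialize_kv_cache()                       # profile, wait for the engine signal,
                                                            # then worker.initialize_cache(num_gpu_blocks=...)
    worker_proc.worker.graph_optimize_and_warm_up_model()   # CUDA graph capture happens last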

View File

@@ -131,9 +131,9 @@ class XpuWorker(WorkerBase):
         """ """
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """ """
-        pass
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,

@@ -159,7 +159,3 @@ class XpuWorker(WorkerBase):
     def check_health(self) -> bool:
         """ """
         return True
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """ """
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)