mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] Support using prefix-caching + cudagraph for inference (#2924)
* fix the bug in cudagraph+prefix-caching but still have some bug with profile Change-Id: Ibf2ba3f2e3b08641d03f4b1391d7c862c3efa397 * add the signal to make sure cache manager launched * fix judge condition * reomove useless control * update control stream * update * fix xpu * change the do_profile flag * update * add new threads to init cache_manager --------- Co-authored-by: RAM <gstian5555@outlook.com>
This commit is contained in:
@@ -863,8 +863,6 @@ class EngineArgs:
|
||||
graph_opt_cfg = self.create_graph_optimization_config()
|
||||
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
|
||||
|
||||
assert not (self.use_cudagraph and self.enable_prefix_caching), "Prefix caching cannot be used with CUDA graph"
|
||||
|
||||
assert not (
|
||||
self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
|
||||
), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
|
||||
|
@@ -183,6 +183,7 @@ class LLMEngine:
|
||||
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
|
||||
pid_suffix=self.ipc_signal_suffix,
|
||||
)
|
||||
self.launched_cache_manager_signal.value[0] = 1
|
||||
|
||||
self.worker_proc = self._start_worker_service()
|
||||
console_logger.info("Waitting worker processes ready...")
|
||||
@@ -217,9 +218,6 @@ class LLMEngine:
|
||||
# Start TokenProcessor thread
|
||||
self.token_processor.run()
|
||||
|
||||
if self.do_profile:
|
||||
self._stop_profile()
|
||||
|
||||
if self.cfg.splitwise_role != "mixed":
|
||||
# 单机逻辑
|
||||
self.engine_worker_queue.available_prefill_instances.put(1)
|
||||
@@ -849,6 +847,17 @@ class LLMEngine:
|
||||
create=True,
|
||||
)
|
||||
|
||||
# launched_cache_manager_signal 用于感知engine是否启动了cache_manager
|
||||
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
|
||||
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
|
||||
self.launched_cache_manager_signal = IPCSignal(
|
||||
name="launched_cache_manager_signal",
|
||||
array=launched_cache_manager_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
# worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
|
||||
worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
|
||||
self.worker_healthy_live_signal = IPCSignal(
|
||||
@@ -1133,6 +1142,7 @@ class LLMEngine:
|
||||
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
|
||||
pid_suffix=self.ipc_signal_suffix,
|
||||
)
|
||||
self.launched_cache_manager_signal.value[0] = 1
|
||||
|
||||
def check_health(self, time_interval_threashold=30):
|
||||
"""
|
||||
@@ -1171,6 +1181,10 @@ class LLMEngine:
|
||||
|
||||
self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
|
||||
self.checking_worker_status_thread.start()
|
||||
checking_worker_init_kv_cache_status_thread = None
|
||||
if self.do_profile:
|
||||
checking_worker_init_kv_cache_status_thread = threading.Thread(target=self._stop_profile, daemon=True)
|
||||
checking_worker_init_kv_cache_status_thread.start()
|
||||
|
||||
# display weight loadding progress
|
||||
with tqdm(total=100, desc="Loading Weights") as pbar:
|
||||
@@ -1201,6 +1215,8 @@ class LLMEngine:
|
||||
self.worker_init_status["finished"] = True
|
||||
try:
|
||||
self.checking_worker_status_thread.join(timeout=1)
|
||||
if checking_worker_init_kv_cache_status_thread is not None:
|
||||
checking_worker_init_kv_cache_status_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
@@ -151,8 +151,6 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
"""
|
||||
Process inputs for prefill tasks and insert it to share_inputs buffer
|
||||
"""
|
||||
if "caches" not in self.share_inputs:
|
||||
self.initialize_kv_cache()
|
||||
|
||||
if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
|
||||
os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
|
||||
@@ -561,7 +559,7 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
self.initialize_kv_cache()
|
||||
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
|
||||
|
||||
def initialize_kv_cache(self) -> None:
|
||||
def initialize_kv_cache(self, profile: bool = False) -> None:
|
||||
"""
|
||||
Initialize kv cache
|
||||
"""
|
||||
@@ -582,7 +580,7 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
|
||||
# local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
|
||||
if not self.parallel_config.do_profile and (
|
||||
if not profile and (
|
||||
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
||||
):
|
||||
raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")
|
||||
@@ -1012,7 +1010,7 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
|
||||
# Initialize kv cache for profile run. After profile run kv cache will be reset.
|
||||
self.num_gcu_blocks = self.parallel_config.total_block_num
|
||||
self.initialize_kv_cache()
|
||||
self.initialize_kv_cache(profile=True)
|
||||
|
||||
# 1. Profile with multimodal encoder & encoder cache
|
||||
|
||||
@@ -1038,7 +1036,6 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
self.num_gcu_blocks = num_gpu_blocks
|
||||
|
||||
# Reset block table and kv cache with global block num
|
||||
if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
|
||||
self.initialize_kv_cache()
|
||||
|
||||
# Reset free list
|
||||
@@ -1057,8 +1054,6 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
}
|
||||
)
|
||||
|
||||
self.parallel_config.do_profile = False
|
||||
|
||||
if self.speculative_method in ["mtp"]:
|
||||
self.proposer.update_block_num(num_gpu_blocks)
|
||||
|
||||
|
@@ -98,9 +98,9 @@ class GcuWorker(WorkerBase):
|
||||
""" """
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
def initialize_cache(self, num_gpu_blocks: int) -> None:
|
||||
""" """
|
||||
pass
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -134,7 +134,3 @@ class GcuWorker(WorkerBase):
|
||||
def cal_theortical_kvcache(self) -> int:
|
||||
""" """
|
||||
return self.model_runner.cal_theortical_kvcache()
|
||||
|
||||
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
|
||||
""" """
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
@@ -193,9 +193,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
Process inputs for prefill tasks and insert it to share_inputs buffer
|
||||
TODO(gongshaotian): Refactor this func
|
||||
"""
|
||||
# NOTE(luotingdan): Lazy initialize kv cache
|
||||
if "caches" not in self.share_inputs:
|
||||
self.initialize_kv_cache()
|
||||
|
||||
# NOTE(luotingdan): Set environment variable of prefill node
|
||||
if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
|
||||
@@ -700,7 +697,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
for attn_backend in self.attn_backends:
|
||||
attn_backend.init_attention_metadata(self.forward_meta)
|
||||
|
||||
def initialize_kv_cache(self) -> None:
|
||||
def initialize_kv_cache(self, profile: bool = False) -> None:
|
||||
"""
|
||||
Initialize kv cache
|
||||
"""
|
||||
@@ -721,7 +718,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
|
||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
|
||||
if not self.parallel_config.do_profile and (
|
||||
if not profile and (
|
||||
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
||||
):
|
||||
cache_kvs_list = []
|
||||
@@ -739,7 +736,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
else:
|
||||
for i in range(self.model_config.num_hidden_layers):
|
||||
|
||||
cache_kvs[f"key_caches_{i}"] = paddle.full(
|
||||
shape=kv_cache_shape,
|
||||
fill_value=0,
|
||||
@@ -1218,7 +1214,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
# Initialize kv cache for profile run. After profile run kv cache will be reset.
|
||||
# TODO(gongshaotian): Optimize the management logic of kvcache
|
||||
self.num_gpu_blocks = self.parallel_config.total_block_num
|
||||
self.initialize_kv_cache()
|
||||
self.initialize_kv_cache(profile=True)
|
||||
|
||||
# 1. Profile with multimodal encoder & encoder cache
|
||||
|
||||
@@ -1243,7 +1239,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
|
||||
# Reset block table and kv cache with global block num
|
||||
if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
|
||||
self.initialize_kv_cache()
|
||||
|
||||
# Reset free list
|
||||
@@ -1262,8 +1257,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
}
|
||||
)
|
||||
|
||||
self.parallel_config.do_profile = False
|
||||
|
||||
if self.speculative_method in ["mtp"]:
|
||||
self.proposer.update_block_num(num_gpu_blocks)
|
||||
|
||||
|
@@ -165,9 +165,10 @@ class GpuWorker(WorkerBase):
|
||||
"""Get current model"""
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
"""Initizlize the KV Cache"""
|
||||
pass
|
||||
def initialize_cache(self, num_gpu_blocks: int) -> None:
|
||||
"""Initizlize the KV Cache with accurate num_gpu_blocks"""
|
||||
# accurate cache size
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -198,7 +199,3 @@ class GpuWorker(WorkerBase):
|
||||
def cal_theortical_kvcache(self) -> int:
|
||||
"""Calculate the block memory required"""
|
||||
return self.model_runner.cal_theortical_kvcache()
|
||||
|
||||
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
|
||||
"""Reinitialize the kv cache using the parameters from the profile"""
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
@@ -141,9 +141,6 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
Process inputs for prefill tasks and insert it to share_inputs buffer
|
||||
TODO(gongshaotian): Refactor this func
|
||||
"""
|
||||
# NOTE(luotingdan): Lazy initialize kv cache
|
||||
if "caches" not in self.share_inputs:
|
||||
self.initialize_kv_cache()
|
||||
|
||||
# NOTE(luotingdan): Set environment variable of prefill node
|
||||
if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
|
||||
@@ -552,7 +549,7 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
if self.forward_meta is not None:
|
||||
self.forward_meta.clear_caches()
|
||||
|
||||
def initialize_kv_cache(self) -> None:
|
||||
def initialize_kv_cache(self, profile: bool = False) -> None:
|
||||
"""
|
||||
Initialize kv cache
|
||||
"""
|
||||
@@ -992,7 +989,7 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
# Initialize kv cache for profile run. After profile run kv cache will be reset.
|
||||
# TODO(gongshaotian): Optimize the management logic of kvcache
|
||||
self.num_gpu_blocks = self.parallel_config.total_block_num
|
||||
self.initialize_kv_cache()
|
||||
self.initialize_kv_cache(profile=True)
|
||||
|
||||
# 1. Profile with multimodal encoder & encoder cache
|
||||
|
||||
@@ -1016,7 +1013,6 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
|
||||
# Reset block table and kv cache with global block num
|
||||
if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
|
||||
self.initialize_kv_cache()
|
||||
|
||||
# Reset free list
|
||||
@@ -1035,8 +1031,6 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
}
|
||||
)
|
||||
|
||||
self.parallel_config.do_profile = False
|
||||
|
||||
def cal_theortical_kvcache(self):
|
||||
"""
|
||||
Calculate the total block memory required at the model level
|
||||
|
@@ -99,9 +99,9 @@ class IluvatarWorker(WorkerBase):
|
||||
""" """
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
def initialize_cache(self, num_gpu_blocks: int) -> None:
|
||||
""" """
|
||||
pass
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -135,7 +135,3 @@ class IluvatarWorker(WorkerBase):
|
||||
def cal_theortical_kvcache(self) -> int:
|
||||
""" """
|
||||
return self.model_runner.cal_theortical_kvcache()
|
||||
|
||||
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
|
||||
""" """
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
@@ -64,7 +64,7 @@ class WorkerBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
def initialize_cache(self, num_gpu_blocks: int) -> None:
|
||||
"""Initizlize the KV Cache with the given size in blocks."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -347,7 +347,7 @@ class PaddleDisWorkerProc:
|
||||
|
||||
self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished()
|
||||
|
||||
def determine_num_available_blocks(self) -> None:
|
||||
def initialize_kv_cache(self) -> None:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
KV blocks may be allocated without OOMs.
|
||||
|
||||
@@ -400,8 +400,25 @@ class PaddleDisWorkerProc:
|
||||
self.get_profile_block_num_signal.value[0] = num_blocks_local
|
||||
else:
|
||||
num_blocks_local = self.fd_config.parallel_config.total_block_num
|
||||
# 4. Updata share inputs
|
||||
self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)
|
||||
|
||||
logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
|
||||
# wait engine launch cache_manager
|
||||
if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
|
||||
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
|
||||
self.launched_cache_manager_signal = IPCSignal(
|
||||
name="launched_cache_manager_signal",
|
||||
array=launched_cache_manager_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.parallel_config.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
while np.any(self.launched_cache_manager_signal.value[0] <= 0):
|
||||
time.sleep(0.01)
|
||||
# 4. init kv_cache with accurate num_blocks
|
||||
self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
|
||||
|
||||
def graph_optimize_and_warm_up_model(self) -> None:
|
||||
self.worker.graph_optimize_and_warm_up_model()
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize device and Construct model runner"""
|
||||
@@ -714,8 +731,8 @@ def run_worker_proc() -> None:
|
||||
|
||||
# Load model
|
||||
worker_proc.load_model()
|
||||
logger.info("determine_num_available_blocks")
|
||||
worker_proc.determine_num_available_blocks()
|
||||
# Initialize KV Cache
|
||||
worker_proc.initialize_kv_cache()
|
||||
|
||||
# Trigger CUDAGraph capture
|
||||
worker_proc.worker.graph_optimize_and_warm_up_model()
|
||||
|
@@ -131,9 +131,9 @@ class XpuWorker(WorkerBase):
|
||||
""" """
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
def initialize_cache(self, num_gpu_blocks: int) -> None:
|
||||
""" """
|
||||
pass
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -159,7 +159,3 @@ class XpuWorker(WorkerBase):
|
||||
def check_health(self) -> bool:
|
||||
""" """
|
||||
return True
|
||||
|
||||
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
|
||||
""" """
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
Reference in New Issue
Block a user