[Feature] Support using prefix-caching + cudagraph for inference (#2924)

* fix the bug in cudagraph+prefix-caching but still have some bug with profile

Change-Id: Ibf2ba3f2e3b08641d03f4b1391d7c862c3efa397

* add the signal to make sure cache manager launched

* fix judge condition

* reomove useless control

* update control stream

* update

* fix xpu

* change the do_profile flag

* update

* add new threads to init cache_manager

---------

Co-authored-by: RAM <gstian5555@outlook.com>
This commit is contained in:
Zero Rains
2025-07-22 15:59:45 +08:00
committed by GitHub
parent 48e6a0ca26
commit 89a485b69f
11 changed files with 63 additions and 65 deletions

View File

@@ -141,9 +141,6 @@ class IluvatarModelRunner(ModelRunnerBase):
Process inputs for prefill tasks and insert it to share_inputs buffer
TODO(gongshaotian): Refactor this func
"""
# NOTE(luotingdan): Lazy initialize kv cache
if "caches" not in self.share_inputs:
self.initialize_kv_cache()
# NOTE(luotingdan): Set environment variable of prefill node
if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -552,7 +549,7 @@ class IluvatarModelRunner(ModelRunnerBase):
if self.forward_meta is not None:
self.forward_meta.clear_caches()
def initialize_kv_cache(self) -> None:
def initialize_kv_cache(self, profile: bool = False) -> None:
"""
Initialize kv cache
"""
@@ -992,7 +989,7 @@ class IluvatarModelRunner(ModelRunnerBase):
# Initialize kv cache for profile run. After profile run kv cache will be reset.
# TODO(gongshaotian): Optimize the management logic of kvcache
self.num_gpu_blocks = self.parallel_config.total_block_num
self.initialize_kv_cache()
self.initialize_kv_cache(profile=True)
# 1. Profile with multimodal encoder & encoder cache
@@ -1016,8 +1013,7 @@ class IluvatarModelRunner(ModelRunnerBase):
self.num_gpu_blocks = num_gpu_blocks
# Reset block table and kv cache with global block num
if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
self.initialize_kv_cache()
self.initialize_kv_cache()
# Reset free list
free_list = list(
@@ -1035,8 +1031,6 @@ class IluvatarModelRunner(ModelRunnerBase):
}
)
self.parallel_config.do_profile = False
def cal_theortical_kvcache(self):
"""
Calculate the total block memory required at the model level