[FDConfig] add block number verification (#4983)

* Update config.py

* fix

* update unit test

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
Author: ltd0924
Date: 2025-11-13 09:48:44 +08:00 (committed by GitHub)
Parent: 1c0b0b08b7
Commit: 303c986cc7
4 changed files with 30 additions and 17 deletions


@@ -1290,6 +1290,9 @@ class CacheConfig:
                self.prefill_kvcache_block_num = self.total_block_num
            else:
                self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
            assert (
                self.prefill_kvcache_block_num >= self.max_block_num_per_seq
            ), f"current block number :{self.prefill_kvcache_block_num} should be greater than or equal to current model len needed minimum block number :{self.max_block_num_per_seq}"
        else:
            length = num_total_tokens // number_of_tasks
            block_num = (length + self.block_size - 1 + self.dec_token_num) // self.block_size
@@ -1310,6 +1313,9 @@ class CacheConfig:
f"Reset block num, the total_block_num:{self.total_block_num},"
f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"
)
assert (
self.prefill_kvcache_block_num >= self.max_block_num_per_seq
), f"current block number :{self.prefill_kvcache_block_num} should be greater than or equal to current model len needed minimum block number :{self.max_block_num_per_seq}"
def print(self):
"""
@@ -1585,8 +1591,8 @@ class FDConfig:
        if self.long_prefill_token_threshold == 0:
            self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)
        self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs)
        self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size)
        self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs)
        if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.cache_config.enable_prefix_caching = False
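
The FDConfig hunk above is an ordering fix rather than new logic: postprocess now asserts against cache_config.max_block_num_per_seq, so that attribute has to be assigned before the call, whereas the old order computed it afterwards. The rough stand-in below only illustrates that dependency; the stub and the SimpleNamespace objects are hypothetical, not FastDeploy's real API.

from types import SimpleNamespace

model_config = SimpleNamespace(max_model_len=4196)  # value borrowed from the tests
cache_config = SimpleNamespace(block_size=64)       # value borrowed from the tests


def postprocess_stub(cfg):
    # Stand-in for CacheConfig.postprocess: the new assert needs the bound to exist.
    assert hasattr(cfg, "max_block_num_per_seq"), "set max_block_num_per_seq before postprocess"


# New ordering: derive the per-sequence block bound first, then post-process.
cache_config.max_block_num_per_seq = int(model_config.max_model_len // cache_config.block_size)  # 65
postprocess_stub(cache_config)
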


@@ -30,7 +30,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over
    )
    args = asdict(engine_args)
    cache_cfg = CacheConfig(args)
    model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=8192)
    model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=4196)
    speculative_cfg = SimpleNamespace(method=None)
    model_cfg.print = print
    cache_cfg.bytes_per_layer_per_block = 1
@@ -48,9 +48,16 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over
    return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed")


def test_block_num_limit():
    import pytest

    with pytest.raises(AssertionError):
        make_prefix_cache_manager(max_num_seqs=3, enable_mm=False, num_gpu_blocks_override=20)


def test_normal_case():
    block_size = 64
    cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=False, num_gpu_blocks_override=100)
    cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=False, num_gpu_blocks_override=128)
    req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
    req2 = Request.from_dict(
        {"request_id": "req2", "prompt_token_ids": [1] * 1600 + [2] * 1600, "prompt_token_ids_len": 3200}
@@ -61,14 +68,14 @@ def test_normal_case():
    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size)
    assert len(common_block_ids) == 0
    assert matched_token_num == 0
    assert len(cache_manager.gpu_free_block_list) == 100
    assert len(cache_manager.gpu_free_block_list) == 128
    req1.block_tables.extend(common_block_ids)
    # allocate for req1 inputs
    num_new_block = 50
    req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
    req1.num_computed_tokens += 50 * block_size
    cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens)
    assert len(cache_manager.gpu_free_block_list) == 50
    assert len(cache_manager.gpu_free_block_list) == 78
    # allocate for req2 inputs
    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size)
    assert len(common_block_ids) == 25
@@ -85,13 +92,13 @@ def test_normal_case():
    assert matched_token_num == 25 * block_size
    req3.num_cached_tokens = matched_token_num
    req3.num_computed_tokens = 25 * block_size
    assert len(cache_manager.gpu_free_block_list) == 25
    assert len(cache_manager.gpu_free_block_list) == 53
    req3.block_tables.extend(common_block_ids)
    num_new_block = 25
    assert cache_manager.can_allocate_gpu_blocks(num_new_block)
    req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
    cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
    assert len(cache_manager.gpu_free_block_list) == 0
    assert len(cache_manager.gpu_free_block_list) == 28
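
The renumbered assertions above follow directly from the larger pool: with max_model_len=4196 and block_size=64 a sequence needs at most 4196 // 64 = 65 blocks, so num_gpu_blocks_override=20 now trips the new assert (hence test_block_num_limit) while 128 passes. The free-list arithmetic below retraces the updated expectations; the 25 newly allocated blocks for req2 are inferred from the old 100/50/25/0 sequence rather than shown verbatim in the truncated hunk.

block_size = 64
max_model_len = 4196
min_blocks_per_seq = max_model_len // block_size  # 65: why override=20 raises and 128 does not

free = 128   # num_gpu_blocks_override in test_normal_case
free -= 50   # req1: 50 newly allocated blocks               -> 78
after_req1 = free
free -= 25   # req2: 25 blocks hit the prefix cache, 25 new  -> 53
after_req2 = free
free -= 25   # req3: 25 blocks hit the prefix cache, 25 new  -> 28
after_req3 = free

assert (min_blocks_per_seq, after_req1, after_req2, after_req3) == (65, 78, 53, 28)
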
def test_mm_extra_keys():


@@ -32,7 +32,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over
    )
    args = asdict(engine_args)
    cache_cfg = CacheConfig(args)
    model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=8192)
    model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=4196)
    speculative_cfg = SimpleNamespace(method=None)
    model_cfg.print = print
    cache_cfg.bytes_per_layer_per_block = 1


@@ -90,7 +90,7 @@ def test_normal_schedule():
def test_preempted_request():
    max_num_seqs = 2
    engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=52, max_num_batched_tokens=3200)
    engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=102, max_num_batched_tokens=3200)
    args = asdict(engine_args)
    cache_cfg = CacheConfig(args)
    model_cfg = SimpleNamespace(enable_mm=False)
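
Raising num_gpu_blocks_override from 52 to 102 is what changes the scheduling expectations in the hunk below. Assuming block_size=64 and the two extra decode blocks implied by the unchanged len(block_tables) == 52 assertion (neither is stated in this hunk), each 3200-token prompt prefills into 50 blocks, so the old budget of 52 could admit only one request while 102 can admit both prefills:

block_size = 64
prompt_tokens = 3200
blocks_per_prompt = prompt_tokens // block_size  # 50
assert blocks_per_prompt + 2 == 52               # matches len(block_tables) == 52 for req1
assert 2 * blocks_per_prompt > 52                # old budget: the second request cannot prefill
assert 2 * blocks_per_prompt <= 102              # new budget: both prefills fit in step 2
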
@@ -127,22 +127,22 @@ def test_preempted_request():
    assert len(resource_manager_v1.waiting) == 1
    # step 2
    scheduler_reqs = resource_manager_v1.schedule()
    assert len(scheduler_reqs) == 1
    assert len(scheduler_reqs) == 2
    assert scheduler_reqs[0].request_id == "req1"
    assert len(scheduler_reqs[0].block_tables) == 52
    # step 3
    req1.output_token_ids.extend([1] * 128)
    scheduler_reqs = resource_manager_v1.schedule()
    assert len(scheduler_reqs) == 1
    assert scheduler_reqs[0].request_id == "req1"
    assert len(resource_manager_v1.running) == 0
    assert len(scheduler_reqs) == 2
    assert scheduler_reqs[0].request_id == "req2"
    assert len(resource_manager_v1.running) == 1
    # to be added into waiting queue
    assert len(resource_manager_v1.waiting) == 1
    assert len(resource_manager_v1.waiting) == 0
    assert "req2" in resource_manager_v1.to_be_rescheduled_request_id_set
    # mock token_processor to add into waiting
    resource_manager_v1.waiting.appendleft(req1)
    resource_manager_v1.waiting.appendleft(req2)
    # step 4
    scheduler_reqs = resource_manager_v1.schedule()
    assert len(scheduler_reqs) == 1
    assert scheduler_reqs[0].request_id == "req1"
    assert len(scheduler_reqs) == 0
    assert len(resource_manager_v1.running) == 1
    assert len(resource_manager_v1.waiting) == 1
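
Read against that capacity math, the revised assertions appear to tell a consistent story: step 2 now schedules both requests instead of one; in step 3 the decode blocks needed for req1's 128 new output tokens push req2 out, so running stays at 1, the waiting queue stays empty, and req2 is tracked in to_be_rescheduled_request_id_set rather than re-queued; after the test manually re-queues req2 via waiting.appendleft(req2), step 4 schedules nothing new because req1 still holds its blocks, leaving one request running and one waiting.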