Unify server-side and model-side Config (Part1) (#3018)

* move cache config

* fix mtp
Author: YuanRisheng
Date: 2025-07-28 10:51:52 +08:00
Committed by: GitHub
Parent: 8f426c1690
Commit: 6ccc10ad47

23 changed files with 243 additions and 289 deletions
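The hunks below move cache-related reads from self.parallel_config to self.cache_config (specifically kv_cache_ratio, block_size, enc_dec_block_num, enable_prefix_caching, and enable_chunked_prefill), while scheduling-side fields such as splitwise_role, do_profile, max_model_len, and total_block_num stay on parallel_config. A minimal sketch of the grouping this implies follows; only the field names come from the diff, the dataclass shape and defaults are assumptions.

    from dataclasses import dataclass

    @dataclass
    class CacheConfig:
        # Field names taken from this commit's diff; defaults are invented.
        block_size: int = 64                 # tokens held per KV-cache block
        enc_dec_block_num: int = 2           # extra blocks reserved per request
        kv_cache_ratio: float = 0.75         # share of the budget given to prefill
        enable_prefix_caching: bool = False
        enable_chunked_prefill: bool = False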


@@ -97,10 +97,10 @@ class MTPProposer(Proposer):
             num_tokens // batch_size,
             self.parallel_config.max_model_len - max_dec_len,
         )
-        input_length = int(full_length * self.parallel_config.kv_cache_ratio)
+        input_length = int(full_length * self.cache_config.kv_cache_ratio)
         block_num = (
-            input_length + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            input_length + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

         for i in range(batch_size):
             idx = i
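For reference, the block math in this hunk is a ceiling division of the prefill length by the block size, plus the reserved encoder/decoder blocks. A standalone recomputation with invented numbers:

    # All values below are illustrative, not from the commit.
    full_length = 1000
    kv_cache_ratio = 0.7
    block_size = 64
    enc_dec_block_num = 2

    input_length = int(full_length * kv_cache_ratio)  # 700
    # (x + b - 1) // b computes ceil(x / b) in integer arithmetic.
    block_num = (input_length + block_size - 1) // block_size + enc_dec_block_num
    print(block_num)  # ceil(700 / 64) + 2 == 11 + 2 == 13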
@@ -141,7 +141,7 @@ class MTPProposer(Proposer):
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
         if not self.parallel_config.do_profile and (
-            self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
+            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
             cache_kvs_list = []
             for i in range(
@@ -219,14 +219,14 @@ class MTPProposer(Proposer):
         self.main_model_num_gpu_blocks = num_gpu_blocks
         self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
             self.initialize_kv_cache()
         # Reset free list
         free_list = list(
             range(
                 self.num_gpu_blocks - 1,
-                int(self.main_model_num_gpu_blocks * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.main_model_num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                 -1,
             )
         )
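The free-list reset in this hunk walks block indices from the top down and stops just above the prefill watermark int(main_model_num_gpu_blocks * kv_cache_ratio), so the highest-numbered blocks are handed out first. A toy reproduction with invented sizes:

    # Invented sizes, chosen only to make the range easy to follow.
    main_model_num_gpu_blocks = 10
    kv_cache_ratio = 0.7
    num_gpu_blocks = 12  # after scaling by num_gpu_block_expand_ratio

    free_list = list(
        range(
            num_gpu_blocks - 1,                                   # start: 11
            int(main_model_num_gpu_blocks * kv_cache_ratio) - 1,  # stop above: 6
            -1,
        )
    )
    print(free_list)  # [11, 10, 9, 8, 7]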
@@ -299,7 +299,7 @@ class MTPProposer(Proposer):
         self.free_list = list(
             range(
                 self.parallel_config.total_block_num - 1,
-                int(self.parallel_config.total_block_num * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                 -1,
             )
         )
@@ -371,7 +371,7 @@ class MTPProposer(Proposer):
]
self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["step_idx"][idx : idx + 1] = 0
if self.parallel_config.enable_chunked_prefill:
if self.cache_config.enable_chunked_prefill:
token_chunk_size = request.prefill_chunk_info[0]
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
@@ -640,7 +640,7 @@ class MTPProposer(Proposer):
             self.model_inputs["used_list_len"],
             self.model_inputs["free_list"],
             self.model_inputs["free_list_len"],
-            self.parallel_config.block_size,
+            self.cache_config.block_size,
             self.max_draft_token_num,
         )