Unify server-side and model-side Config (Part1) (#3018)

* move cache config

* fix mtp
Authored by YuanRisheng on 2025-07-28 10:51:52 +08:00, committed by GitHub
parent 8f426c1690
commit 6ccc10ad47
23 changed files with 243 additions and 289 deletions
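
Note: every hunk below follows the same pattern: cache-related knobs are read
from self.cache_config instead of self.parallel_config. A minimal sketch of the
grouping this implies, with field names taken from the call sites below and
purely illustrative defaults (the actual class in this PR may carry more fields):

    from dataclasses import dataclass

    @dataclass
    class CacheConfig:
        block_size: int = 64                  # tokens per KV-cache block
        kv_cache_ratio: float = 0.75          # watermark: blocks above total * ratio start on the free list
        enc_dec_block_num: int = 2            # extra blocks reserved per request
        enable_chunked_prefill: bool = False  # split long prefills into chunks
        enable_prefix_caching: bool = False   # reuse KV blocks across shared prefixes

Note that total_block_num and splitwise_role stay on parallel_config in these hunks.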


@@ -207,7 +207,7 @@ class GCUModelRunner(ModelRunnerBase):
             self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
             # Use chunked prefill
-            if self.parallel_config.enable_chunked_prefill:
+            if self.cache_config.enable_chunked_prefill:
                 request.set("chunk_idx", 1)
                 logger.info(f"prefill_chunk_info: {request.prefill_chunk_info}")
                 token_chunk_size = request.prefill_chunk_info[0]
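
For context, chunked prefill splits a long prompt into fixed-budget chunks and
prefills them over several steps; chunk_idx tracks progress and prefill_chunk_info
holds the per-chunk sizes. A rough sketch of how such chunk sizes could be derived
(split_prefill_chunks and the token budget are hypothetical, not part of this diff):

    def split_prefill_chunks(prompt_len: int, token_budget: int) -> list[int]:
        """Greedy split: full-budget chunks, then the remainder."""
        chunks = []
        while prompt_len > 0:
            chunks.append(min(prompt_len, token_budget))
            prompt_len -= chunks[-1]
        return chunks

    # A 1000-token prompt under a 384-token budget prefills in three steps.
    assert split_prefill_chunks(1000, 384) == [384, 384, 232]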
@@ -287,10 +287,10 @@ class GCUModelRunner(ModelRunnerBase):
                 num_tokens // batch_size,
                 self.parallel_config.max_model_len - max_dec_len,
             )
-            input_length = int(full_length * self.parallel_config.kv_cache_ratio)
+            input_length = int(full_length * self.cache_config.kv_cache_ratio)
             block_num = (
-                input_length + self.parallel_config.block_size - 1
-            ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+                input_length + self.cache_config.block_size - 1
+            ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
             for i in range(batch_size):
                 idx = i
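
The block_num expression is a ceiling division plus a fixed reservation. A worked
example with illustrative numbers (none of these values come from the diff):

    block_size = 64          # tokens per block
    enc_dec_block_num = 2    # reserved blocks
    full_length = 300
    kv_cache_ratio = 0.7

    input_length = int(full_length * kv_cache_ratio)   # 210
    block_num = (input_length + block_size - 1) // block_size + enc_dec_block_num
    assert block_num == 6                               # ceil(210 / 64) = 4, plus 2 reserved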
@@ -417,15 +417,15 @@ class GCUModelRunner(ModelRunnerBase):
         # Set block tables
         pre_max_block_num = (
-            self.parallel_config.max_model_len + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            self.parallel_config.max_model_len + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
         self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")
         # Initialize free list
         free_list = list(
             range(
                 self.parallel_config.total_block_num - 1,
-                int(self.parallel_config.total_block_num * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                 -1,
             )
         )
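
The free list enumerates block indices from the top of the pool down to the
kv_cache_ratio watermark, so blocks are handed out highest-index first. With
illustrative numbers:

    total_block_num = 1000
    kv_cache_ratio = 0.75

    free_list = list(range(total_block_num - 1, int(total_block_num * kv_cache_ratio) - 1, -1))
    assert free_list[0] == 999 and free_list[-1] == 750 and len(free_list) == 250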
@@ -608,9 +608,7 @@ class GCUModelRunner(ModelRunnerBase):
         )
         # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (
-            self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
-        ):
+        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
             raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")
         else:
             for i in range(self.model_config.num_hidden_layers):
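
Note: after the move this condition deliberately mixes the two configs:
enable_prefix_caching is a cache concern and now lives on cache_config, while
splitwise_role describes the deployment mode and stays on parallel_config.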
@@ -795,7 +793,7 @@ class GCUModelRunner(ModelRunnerBase):
         """
         Update chunked prefill related parameters
         """
-        if not self.parallel_config.enable_chunked_prefill:
+        if not self.cache_config.enable_chunked_prefill:
             return
         for task in tasks:
             if task.get("prefill_chunk_info", None) is None:
@@ -861,7 +859,7 @@ class GCUModelRunner(ModelRunnerBase):
             A list of indices corresponding to the requests that need to be skipped.
         """
         skip_idx_list = []
-        if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None:
+        if not self.cache_config.enable_chunked_prefill or self.guided_backend is None:
             return skip_idx_list
         for task in model_forward_batch:
@@ -1079,7 +1077,7 @@ class GCUModelRunner(ModelRunnerBase):
         free_list = list(
             range(
                 self.num_gcu_blocks - 1,
-                int(self.num_gcu_blocks * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.num_gcu_blocks * self.cache_config.kv_cache_ratio) - 1,
                 -1,
             )
         )
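
This is the same free-list pattern as in the initialization hunk above, rebuilt
over num_gcu_blocks (presumably the profiled block count) instead of the
configured total_block_num.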
@@ -1123,7 +1121,7 @@ class GCUModelRunner(ModelRunnerBase):
             if self.speculative_method in ["mtp"]
             else self.model_config.num_hidden_layers
         )
-        required_memory = byte_of_dtype * 2 * (self.parallel_config.block_size * hidden_dim) * num_layers  # k + v
+        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
         return required_memory

     def not_need_stop(self) -> bool:
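
required_memory is the per-block KV-cache footprint: element size times two
tensors (K and V) times tokens-per-block times per-token hidden width times the
layer count (which, per the hunk above, includes the extra MTP layer when
speculative decoding uses MTP). A worked example with illustrative values:

    byte_of_dtype = 2      # bfloat16
    block_size = 64        # tokens per block
    hidden_dim = 1024      # kv heads * head_dim, assumed
    num_layers = 32

    required_memory = byte_of_dtype * 2 * (block_size * hidden_dim) * num_layers  # k + v
    assert required_memory == 8 * 1024 * 1024   # 8 MiB per block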