[Feature] Optimize prefix cache (#3208)

* [LLM] support ep

* Update worker_process.py

* Update expert_service.py

* Update worker_process.py

* format files

* optimize prefix cache

* optimize prefix cache

* optimize prefix cache

* pre commit format

* pre commit format

* pre commit format

* Update cache_messager.py

Author: ltd0924
Date: 2025-08-05 17:13:11 +08:00
Committed by: GitHub
Parent: 9f9971844f
Commit: dcf9c2daff
7 changed files with 314 additions and 147 deletions

@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative
 from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.model_executor.ops.gpu import (
     recover_decode_task,
+    set_data_ipc,
     set_value_by_flags_and_idx,
     share_external_data,
 )
@@ -904,7 +905,7 @@ class GPUModelRunner(ModelRunnerBase):
         )
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and self.parallel_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(self.model_config.num_hidden_layers):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
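
Note: this hunk narrows the external-cache path. Previously the worker skipped allocating real KV tensors whenever prefix caching or a non-"mixed" splitwise role was enabled; now only disaggregated (splitwise) workers take that path, and prefix caching in mixed mode falls through to local allocation, which the next hunk then exports over IPC. A rough sketch of the resulting decision, with hypothetical helper names standing in for the two branches (only the branching condition itself comes from the diff):

def attach_external_cache() -> str:
    # Splitwise prefill/decode workers map KV tensors owned by another process
    # (presumably the share_external_data path kept by this hunk).
    return "external"

def allocate_local_cache() -> str:
    # Mixed-role workers allocate the paddle.full(...) KV tensors themselves.
    return "local"

def choose_cache_path(profile: bool, splitwise_role: str) -> str:
    if not profile and splitwise_role != "mixed":
        return attach_external_cache()
    return allocate_local_cache()

# With prefix caching and a "mixed" role, the worker now owns its own cache:
assert choose_cache_path(profile=False, splitwise_role="mixed") == "local"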
@@ -930,6 +931,15 @@
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if self.cache_config.enable_prefix_caching:
+                    set_data_ipc(
+                        cache_kvs[f"key_caches_{i}"],
+                        f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
+                    )
+                    set_data_ipc(
+                        cache_kvs[f"value_caches_{i}"],
+                        f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
+                    )
         self.share_inputs["caches"] = list(cache_kvs.values())
         for value in cache_kvs.values():
             del value
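
With prefix caching enabled in mixed mode, the worker now allocates the KV cache itself and registers every layer's key and value tensor under a deterministic name via set_data_ipc, so that a separate process (for example the cache messager touched elsewhere in this commit) can attach to the same device memory by name instead of receiving copies. A minimal sketch of that naming convention, using the set_data_ipc op imported above; the standalone function and example arguments are illustrative, not the runner's actual code:

from typing import Dict

import paddle
from fastdeploy.model_executor.ops.gpu import set_data_ipc


def export_kv_cache_ipc(cache_kvs: Dict[str, paddle.Tensor], num_layers: int, local_rank: int, device_id: int) -> None:
    # One IPC handle per layer, per tensor-parallel rank, per device, mirroring
    # the f-strings in the hunk above.
    for i in range(num_layers):
        set_data_ipc(cache_kvs[f"key_caches_{i}"], f"key_caches_{i}_rank{local_rank}.device{device_id}")
        set_data_ipc(cache_kvs[f"value_caches_{i}"], f"value_caches_{i}_rank{local_rank}.device{device_id}")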
@@ -1138,6 +1148,8 @@
             if task.chunk_idx > len(task.prefill_chunk_info):
                 continue
             self.restore_chunked_prefill_request[task.request_id] = task
+        if len(self.restore_chunked_prefill_request) > 0:
+            self.share_inputs["not_need_stop"][0] = True
         for id, task in list(self.restore_chunked_prefill_request.items()):
             idx = task.idx
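
The two added lines keep the execution loop alive while restored chunked-prefill requests are still pending: even if no other request needs another step, the not_need_stop flag must stay true until those chunks have been prefilled. A simplified sketch of that guard with hypothetical names (the real flag lives in share_inputs["not_need_stop"]):

def should_keep_running(active_requests: int, restore_chunked_prefill_request: dict) -> bool:
    # Normal case: there is still work from running requests.
    if active_requests > 0:
        return True
    # Added behaviour: pending restored chunked-prefill tasks also count as work,
    # otherwise their remaining prompt chunks would never be processed.
    return len(restore_chunked_prefill_request) > 0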
@@ -1182,7 +1194,7 @@
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["prompt_lens"][idx : idx + 1] += token_chunk_size
                 self.share_inputs["step_idx"][idx : idx + 1] = 0
                 self.share_inputs["stop_flags"][idx : idx + 1] = False
             if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled():
                 self.proposer.update_task_chunk_prefill(task)
             task.chunk_idx += 1
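
For reference, the fields written in this hunk advance one request to its next prefill chunk. A simplified illustration of the bookkeeping (a plain dict stands in for the per-request slice of share_inputs; the field meanings are the usual interpretation, not taken verbatim from FastDeploy documentation):

def advance_chunk(state: dict, token_chunk_size: int) -> None:
    state["seq_lens_encoder"] = token_chunk_size  # prompt tokens to encode this step
    state["prompt_lens"] += token_chunk_size      # prompt tokens prefilled so far
    state["step_idx"] = 0                         # no decode steps taken yet
    state["stop_flags"] = False                   # request stays active


state = {"seq_lens_encoder": 0, "prompt_lens": 512, "step_idx": 0, "stop_flags": False}
advance_chunk(state, token_chunk_size=512)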
@@ -1507,12 +1519,12 @@
         hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads
         # NOTE(liuzichang): Implement multi-layer MTP architecture in the future
-        num_layers = (
+        num_hidden_layers = (
             self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio
             if self.speculative_method in ["mtp"]
             else self.model_config.num_hidden_layers
         )
-        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
+        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_hidden_layers  # k + v
         return required_memory

     def not_need_stop(self) -> bool:
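
The last hunk only renames num_layers to num_hidden_layers; the formula itself is unchanged: bytes per element × 2 (key and value) × block_size × hidden_dim (head_dim × kv_num_heads) × number of layers, with the MTP expansion added when speculative decoding uses MTP. A worked example with assumed model dimensions (illustrative values, not a specific model):

byte_of_dtype = 2       # e.g. a 16-bit KV-cache dtype
block_size = 64         # tokens per KV-cache block
head_dim = 128
kv_num_heads = 8
num_hidden_layers = 32  # no MTP block expansion in this example

hidden_dim = head_dim * kv_num_heads
required_memory = byte_of_dtype * 2 * (block_size * hidden_dim) * num_hidden_layers  # k + v
print(f"{required_memory / 2**20:.1f} MiB per block")  # -> 8.0 MiB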