Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
[feat] support prefix cache clearing when /clear_load_weight is called (#4091)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [feat] support clearing prefix cache (cherry-picked from release/2.1)
* [fix] fix ipc suffix, use port instead
* [fix] fix prefix caching not enabled
* [fix] fix code style
* [fix] wait for rank0 to update weight status
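The cache-ready handshake this change relies on is an IPC flag that the process owning the kv cache flips once the shared tensors exist, while the other ranks poll it. Below is a minimal standalone sketch of that polling pattern using only the Python standard library; FastDeploy's IPCSignal (shown in the diff) wraps the same idea, and the names and sizes here are illustrative assumptions, not the project's API.

    # Sketch of the shared-memory readiness-flag pattern (illustrative stand-in
    # for FastDeploy's IPCSignal; names and sizes are assumptions).
    import time
    import numpy as np
    from multiprocessing import shared_memory

    TP_SIZE = 2  # assumed tensor-parallel degree for the example

    def create_ready_flag(name: str):
        # Owner process creates the shared int32 array and zeroes it (create=True).
        shm = shared_memory.SharedMemory(name=name, create=True, size=TP_SIZE * 4)
        flags = np.ndarray((TP_SIZE,), dtype=np.int32, buffer=shm.buf)
        flags[:] = 0
        return shm, flags

    def attach_ready_flag(name: str):
        # Worker processes attach to the existing array (create=False).
        shm = shared_memory.SharedMemory(name=name, create=False)
        return shm, np.ndarray((TP_SIZE,), dtype=np.int32, buffer=shm.buf)

    def wait_until_ready(flags, rank: int) -> None:
        # Poll until the owner flips this rank's slot, mirroring the wait loop
        # added in the diff below.
        while flags[rank] != 1:
            time.sleep(1)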
@@ -59,6 +59,7 @@ else:
         recover_decode_task,
         set_value_by_flags_and_idx,
         share_external_data,
+        set_data_ipc,
     )

 from fastdeploy.model_executor.pre_and_post_process import (
@@ -73,6 +74,7 @@ if not (current_platform.is_dcu() or current_platform.is_iluvatar()):

 from fastdeploy import envs
 from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
+from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -978,7 +980,7 @@ class GPUModelRunner(ModelRunnerBase):
         """
         Initialize kv cache
         """
-        cache_kvs = {}
+        # cache_kvs = {}
         max_block_num = self.num_gpu_blocks

         # Get kv cache dtype
@@ -999,34 +1001,50 @@ class GPUModelRunner(ModelRunnerBase):
         )
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            cache_kvs_list = []
-            for i in range(self.model_config.num_hidden_layers):
-                key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
-                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
-                cache_kvs_list.append(key_cache)
-                value_cache = paddle.empty(shape=[], dtype=cache_type)
-                value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
-                cache_kvs_list.append(value_cache)
+        cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
+        cache_ready_signal = IPCSignal(
+            name="cache_ready_signal",
+            array=cache_ready_signal_data,
+            dtype=np.int32,
+            suffix=self.parallel_config.engine_worker_queue_port,
+            create=False,
+        )
+
+        # Check if gpu runner needs to create kv cache
+        # 1. During profiling, it creates its own kv cache.
+        # 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled.
+        create_cache_tensor = profile or self.parallel_config.splitwise_role == "mixed"
+
+        if not create_cache_tensor:
+            logger.info("Waiting for cache managers to create kv cache..")
+            while cache_ready_signal.value[self.local_rank] != 1:
+                time.sleep(1)
+            logger.info("OK! Stop waiting.")
+
+        logger.info("Initializing kv cache for all layers.")
+        cache_kvs_list = []
+        for i in range(self.model_config.num_hidden_layers):
+            key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+            if create_cache_tensor:
+                logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
+                key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                set_data_ipc(key_cache, key_cache_name)
+                set_data_ipc(val_cache, val_cache_name)
+            else:
+                logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
+                key_cache = paddle.empty(shape=[], dtype=cache_type)
+                val_cache = paddle.empty(shape=[], dtype=cache_type)
+                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
+                val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
+            cache_kvs_list.extend([key_cache, val_cache])
+        self.share_inputs["caches"] = cache_kvs_list

-            self.share_inputs["caches"] = cache_kvs_list
-        else:
-            for i in range(self.model_config.num_hidden_layers):
-                cache_kvs[f"key_caches_{i}"] = paddle.full(
-                    shape=kv_cache_shape,
-                    fill_value=0,
-                    dtype=cache_type,
-                )
-                cache_kvs[f"value_caches_{i}"] = paddle.full(
-                    shape=kv_cache_shape,
-                    fill_value=0,
-                    dtype=cache_type,
-                )
-            self.share_inputs["caches"] = list(cache_kvs.values())
-            for value in cache_kvs.values():
-                del value
+        if not profile and create_cache_tensor:
+            logger.info("✅ kv cache is ready!")
+            cache_ready_signal.value[self.local_rank] = 1
+
         paddle.device.cuda.empty_cache()

     def initialize_attn_backend(self) -> None:
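For readers skimming the hunk above, the create-vs-attach decision can be restated in isolation. This is only a distilled restatement of the condition already shown in the diff, with a few illustrative checks, and has no FastDeploy dependency; role strings other than "mixed" are placeholders.

    # Distilled restatement of the create-vs-attach decision from the hunk above.
    def should_create_cache_tensor(profile: bool, splitwise_role: str) -> bool:
        # The runner allocates the kv cache itself while profiling or in "mixed" mode;
        # otherwise the cache manager owns the tensors and the runner attaches via IPC.
        return profile or splitwise_role == "mixed"

    # Illustrative checks (role names other than "mixed" are placeholders).
    assert should_create_cache_tensor(profile=True, splitwise_role="decode")
    assert should_create_cache_tensor(profile=False, splitwise_role="mixed")
    assert not should_create_cache_tensor(profile=False, splitwise_role="decode")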
@@ -1672,6 +1690,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs.pop("caches", None)
         if self.forward_meta is not None:
             self.forward_meta.clear_caches()
+        paddle.device.cuda.empty_cache()

     def clear_parameters(self, pid):
         """ " Dynamic model loader use to clear parameters use for RL"""
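The last hunk frees device memory right after the kv cache references are dropped. As a rough standalone illustration (assuming a CUDA build of PaddlePaddle; the share_inputs dict and tensor shapes are made up for the example), dropping the Python references alone only returns blocks to Paddle's caching allocator, while paddle.device.cuda.empty_cache() hands them back to the device:

    # Rough illustration of why empty_cache() follows the reference drop
    # (requires paddle with CUDA; shapes and dict contents are arbitrary).
    import paddle

    share_inputs = {"caches": [paddle.zeros([256, 256], dtype="float16") for _ in range(4)]}

    share_inputs.pop("caches", None)   # drop references to the cache tensors
    paddle.device.cuda.empty_cache()   # return the freed blocks to the GPU

    print(paddle.device.cuda.memory_reserved())  # reserved memory shrinks after empty_cache()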