[feat] support prefix cache clearing when /clear_load_weight is called (#4008)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix key/value_cache_scales indent

* [fix] fix ep group all-reduce

* [fix] fix clear/update lock not working when workers > 1

* [chore] add preemption triggered info log

* [fix] fix code style

* [fix] fix max_num_seqs config

* [fix] do not force enable_prefix_caching=False in dynamic loading

* [fix] fix ci

* Revert "[fix] fix ci"

This reverts commit 0bc6d55cc8.

* [fix] initialize available_gpu_block_num with max_gpu_block_num

* [fix] fix config splitwise_role

* [fix] fix clearing caches synchronization and add more logs

* [chore] print cache_ready_signal in log

* [fix] fix scheduler_config.splitwise_role

* [fix] fix cache_messager cache_ready_signal create=True

* [fix] stop cache messager from launching in mixed deployment
This commit is contained in:
李泳桦
2025-09-28 19:42:53 +08:00
committed by GitHub
parent 59313ed7f9
commit 6265f4385f
20 changed files with 697 additions and 213 deletions

View File

@@ -59,6 +59,7 @@ else:
set_value_by_flags_and_idx,
share_external_data,
speculate_schedule_cache,
set_data_ipc,
)
from fastdeploy.model_executor.pre_and_post_process import (
@@ -75,7 +76,7 @@ import zmq
from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import ZmqIpcClient
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -1146,7 +1147,7 @@ class GPUModelRunner(ModelRunnerBase):
"""
Initialize kv cache
"""
cache_kvs = {}
# cache_kvs = {}
max_block_num = self.num_gpu_blocks
# Get kv cache dtype
@@ -1169,47 +1170,59 @@ class GPUModelRunner(ModelRunnerBase):
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not profile and (
self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
):
cache_kvs_list = []
for i in range(self.model_config.num_hidden_layers):
key_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
cache_kvs_list.append(key_cache)
value_cache = paddle.empty(shape=[], dtype=cache_type)
value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
cache_kvs_list.append(value_cache)
cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
self.share_inputs["caches"] = cache_kvs_list
else:
for i in range(self.model_config.num_hidden_layers):
cache_kvs[f"key_caches_{i}"] = paddle.full(
shape=kv_cache_shape,
fill_value=0,
dtype=cache_type,
)
cache_kvs[f"value_caches_{i}"] = paddle.full(
shape=kv_cache_shape,
fill_value=0,
dtype=cache_type,
)
# Check if gpu runner needs to create kv cache
# 1. During profiling, it creates its own kv cache.
# 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled.
create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed"
if not create_cache_tensor:
logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
while cache_ready_signal.value[self.local_rank] != 1:
time.sleep(1)
logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
cache_kvs_list = []
for i in range(self.model_config.num_hidden_layers):
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
if create_cache_tensor:
logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
set_data_ipc(key_cache, key_cache_name)
set_data_ipc(val_cache, val_cache_name)
cache_kvs_list.extend([key_cache, val_cache])
if kv_cache_quant_type == "block_wise_fp8":
cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
key_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
val_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
self.share_inputs["caches"] = list(cache_kvs.values())
for value in cache_kvs.values():
del value
cache_kvs_list.extend([key_cache_scales, val_cache_scales])
else:
logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
key_cache = paddle.empty(shape=[], dtype=cache_type)
val_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
cache_kvs_list.extend([key_cache, val_cache])
self.share_inputs["caches"] = cache_kvs_list
if not profile and create_cache_tensor:
cache_ready_signal.value[self.local_rank] = 1
logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")
paddle.device.cuda.empty_cache()
def initialize_attn_backend(self) -> None:
@@ -1935,6 +1948,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()
paddle.device.cuda.empty_cache()
def clear_parameters(self, pid):
"""Dynamic model loader use to clear parameters use for RL"""