[feat] support prefix cache clearing when /clear_load_weight is called (#4091)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix code style

* [fix] wait for rank0 to update weight status
Author: 李泳桦
Date: 2025-09-16 11:11:20 +08:00
Committed by: GitHub
Parent: fbb4e0f8d1
Commit: 7ccbcc5a62
17 changed files with 624 additions and 181 deletions
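
For context, the sketch below shows how the new behavior might be exercised from a client once the server is running. Only the /clear_load_weight route comes from the commit title; the HTTP method, host, and port are assumptions and should be checked against the deployed API server.

    # Hypothetical client call. Only the route name is taken from the commit
    # title; the method, host, and port below are assumptions.
    import requests

    API_SERVER = "http://localhost:8188"  # assumed API server address

    # Ask the serving instance to clear dynamically loaded weights; with this
    # commit the prefix cache is cleared along with them.
    resp = requests.get(f"{API_SERVER}/clear_load_weight", timeout=30)
    print(resp.status_code, resp.text)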


@@ -41,7 +41,7 @@ from fastdeploy.config import (
)
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import ExistTaskStatus, IPCSignal, ModelWeightsStatus
from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, parse_quantization
@@ -175,7 +175,7 @@ class PaddleDisWorkerProc:
name="launched_expert_service_signal",
array=launched_expert_service_signal_data,
dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
+suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
@@ -192,7 +192,7 @@ class PaddleDisWorkerProc:
name="worker_ready_signal",
array=workers_ready,
dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
+suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1
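
The suffix change in the hunks above matters because the engine process and every worker process must derive the same shared-memory name for a signal: engine_worker_queue_port is known to both sides from the launch configuration, whereas engine_pid is not. A minimal sketch of the pattern, reusing the constructor arguments shown in the diff; the engine-side create=True call is an assumption for illustration.

    import numpy as np
    from fastdeploy.inter_communicator import IPCSignal

    port = 8002  # assumed engine_worker_queue_port from the launch config

    # Engine side (assumed): create the shared signal keyed by the queue port.
    workers_ready = np.zeros([8], dtype=np.int32)
    worker_ready_signal = IPCSignal(
        name="worker_ready_signal",
        array=workers_ready,
        dtype=np.int32,
        suffix=port,    # was suffix=engine_pid before this commit
        create=True,    # worker processes attach to the same name with create=False
    )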
@@ -260,7 +260,7 @@ class PaddleDisWorkerProc:
self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
while True:
if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
-if self.model_weights_status.value[0] != 0:
+if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
@@ -281,7 +281,7 @@ class PaddleDisWorkerProc:
if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
self.task_queue.read_finish_flag.set(1)
else:
-self.exist_task_signal.value[0] = 1
+self.exist_task_signal.value[0] = ExistTaskStatus.EXIST
if self.parallel_config.tensor_parallel_size > 1:
# Synchronize the signal for other workers
@@ -292,7 +292,7 @@ class PaddleDisWorkerProc:
paddle.distributed.barrier(self.parallel_config.ep_group)
else:
paddle.distributed.barrier(self.parallel_config.tp_group)
-if self.model_weights_signal[0] != 0:
+if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
logger.info(f"Rank: {self.local_rank} has updated parameters.")
from fastdeploy.rl.dynamic_weight_manager import (
DynamicWeightManager,
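
Per the last commit-message bullet, only the first rank of each parallel group polls the shared weight status, and the result is then broadcast so every rank enters the update path together. Below is a condensed, hypothetical helper distilled from the loop above; it assumes an initialized paddle.distributed environment and simplifies the EP/TP group selection.

    import paddle
    from fastdeploy.inter_communicator import ModelWeightsStatus

    def sync_weight_status(model_weights_status, local_rank, tp_size, tp_group):
        """Rank 0 reads the engine-written status; all ranks receive the same value."""
        signal = paddle.zeros([1], dtype=paddle.int32)
        if local_rank % tp_size == 0:
            # Only rank 0 looks at the shared IPC status.
            if model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                signal[0] = int(model_weights_status.value[0])
        # Other ranks block here until rank 0 has published its view.
        paddle.distributed.broadcast(signal, src=0, group=tp_group)
        return int(signal[0])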
@@ -304,16 +304,16 @@ class PaddleDisWorkerProc:
self.worker.model_runner,
self.parallel_config.engine_worker_queue_port,
)
-self.model_weights_signal[0] = 0
+self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
-if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
+if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
self.insert_step = True
tasks, read_finish = self.task_queue.get_tasks()
if read_finish:
# Ensure that every worker get the task
-self.exist_task_signal.value[0] = 0
+self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
self.task_queue.read_finish_flag.set(0)
req_dicts = []
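
The literals replaced in the hunk above hint at what the new constants stand for: NORMAL takes over the bare 0 for the weight status, while EXIST and EMPTY take over 1 and 0 for the task signal. The real definitions live in fastdeploy.inter_communicator; the values below are only inferred from the replaced literals, not confirmed.

    # Inferred sketch only; the shipped classes in fastdeploy.inter_communicator
    # are authoritative. Values are guessed from the integers they replace.
    class ModelWeightsStatus:
        NORMAL = 0   # replaces the bare 0 in the weight-status checks

    class ExistTaskStatus:
        EMPTY = 0    # replaces exist_task_signal.value[0] = 0
        EXIST = 1    # replaces exist_task_signal.value[0] = 1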
@@ -389,7 +389,7 @@ class PaddleDisWorkerProc:
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
+suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
self.get_profile_block_num_signal.value[0] = num_blocks_local
@@ -397,18 +397,7 @@ class PaddleDisWorkerProc:
num_blocks_local = self.fd_config.parallel_config.total_block_num
logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
# wait engine launch cache_manager
-if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
-launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
-self.launched_cache_manager_signal = IPCSignal(
-name="launched_cache_manager_signal",
-array=launched_cache_manager_signal_data,
-dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
-create=False,
-)
-while np.any(self.launched_cache_manager_signal.value[0] <= 0):
-time.sleep(0.01)
# 4. init kv_cache with accurate num_blocks
self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
@@ -443,7 +432,7 @@ class PaddleDisWorkerProc:
name="loaded_model_signal",
array=loaded_model_signal_data,
dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
+suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
if self.ranks > 1: