mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[feat] support prefix cache clearing when /clear_load_weight
is called (#4091)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [feat] support clearing prefix cache (cherry-picked from release/2.1)
* [fix] fix ipc suffix, use port instead
* [fix] fix prefix caching not enabled
* [fix] fix code style
* [fix] wait for rank0 to update weight status
This commit is contained in:
@@ -41,7 +41,7 @@ from fastdeploy.config import (
 )
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import ExistTaskStatus, IPCSignal, ModelWeightsStatus
 from fastdeploy.model_executor.layers.quantization import get_quantization_config
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import get_logger, parse_quantization
@@ -175,7 +175,7 @@ class PaddleDisWorkerProc:
             name="launched_expert_service_signal",
             array=launched_expert_service_signal_data,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )
         while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
@@ -192,7 +192,7 @@ class PaddleDisWorkerProc:
             name="worker_ready_signal",
             array=workers_ready,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )
         self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1
@@ -260,7 +260,7 @@ class PaddleDisWorkerProc:
         self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
-                if self.model_weights_status.value[0] != 0:
+                if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
                 paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
@@ -281,7 +281,7 @@ class PaddleDisWorkerProc:
                 if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
                     self.task_queue.read_finish_flag.set(1)
                 else:
-                    self.exist_task_signal.value[0] = 1
+                    self.exist_task_signal.value[0] = ExistTaskStatus.EXIST

             if self.parallel_config.tensor_parallel_size > 1:
                 # Synchronize the signal for other workers
@@ -292,7 +292,7 @@ class PaddleDisWorkerProc:
                     paddle.distributed.barrier(self.parallel_config.ep_group)
                 else:
                     paddle.distributed.barrier(self.parallel_config.tp_group)
-            if self.model_weights_signal[0] != 0:
+            if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
                 logger.info(f"Rank: {self.local_rank} has updated parameters.")
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
@@ -304,16 +304,16 @@ class PaddleDisWorkerProc:
                     self.worker.model_runner,
                     self.parallel_config.engine_worker_queue_port,
                 )
-                self.model_weights_signal[0] = 0
+                self.model_weights_signal[0] = ModelWeightsStatus.NORMAL

-            if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
+            if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
                 self.insert_step = True

                 tasks, read_finish = self.task_queue.get_tasks()
                 if read_finish:
                     # Ensure that every worker get the task
-                    self.exist_task_signal.value[0] = 0
+                    self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
                     self.task_queue.read_finish_flag.set(0)

                 req_dicts = []
||||
@@ -389,7 +389,7 @@ class PaddleDisWorkerProc:
             name="get_profile_block_num",
             array=get_profile_block_num,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )
         self.get_profile_block_num_signal.value[0] = num_blocks_local
||||
@@ -397,18 +397,7 @@ class PaddleDisWorkerProc:
             num_blocks_local = self.fd_config.parallel_config.total_block_num

         logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
-        # wait engine launch cache_manager
-        if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
-            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
-            self.launched_cache_manager_signal = IPCSignal(
-                name="launched_cache_manager_signal",
-                array=launched_cache_manager_signal_data,
-                dtype=np.int32,
-                suffix=self.parallel_config.engine_pid,
-                create=False,
-            )
-            while np.any(self.launched_cache_manager_signal.value[0] <= 0):
-                time.sleep(0.01)
         # 4. init kv_cache with accurate num_blocks
         self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
||||
@@ -443,7 +432,7 @@ class PaddleDisWorkerProc:
             name="loaded_model_signal",
             array=loaded_model_signal_data,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )
         if self.ranks > 1:
||||
|
Reference in New Issue
Block a user