mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[feat] support prefix cache clearing when /clear_load_weight
is called (#4091)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [feat] support clearing prefix cache (cherry-picked from release/2.1) * [fix] fix ipc suffix, use port instead * [fix] fix prefix caching not enabled * [fix] fix code style * [fix] wait for rank0 to update weight status
This commit is contained in:
@@ -174,6 +174,24 @@ class EngineSevice:
|
||||
create=True,
|
||||
)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=current_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
swap_space_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
|
||||
self.swap_space_ready_signal = IPCSignal(
|
||||
name="swap_space_ready_signal",
|
||||
array=swap_space_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=current_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
model_weights_status = np.zeros([1], dtype=np.int32)
|
||||
self.model_weights_status_signal = IPCSignal(
|
||||
name="model_weights_status",
|
||||
@@ -183,6 +201,24 @@ class EngineSevice:
|
||||
create=True,
|
||||
)
|
||||
|
||||
prefix_tree_status = np.zeros([1], dtype=np.int32)
|
||||
self.prefix_tree_status_signal = IPCSignal(
|
||||
name="prefix_tree_status",
|
||||
array=prefix_tree_status,
|
||||
dtype=np.int32,
|
||||
suffix=current_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
kv_cache_status = np.zeros([1], dtype=np.int32)
|
||||
self.kv_cache_status_signal = IPCSignal(
|
||||
name="kv_cache_status",
|
||||
array=kv_cache_status,
|
||||
dtype=np.int32,
|
||||
suffix=current_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
def start_worker_queue_service(self, start_queue):
|
||||
"""
|
||||
start queue service for engine worker communication
|
||||
@@ -749,7 +785,7 @@ class EngineSevice:
|
||||
|
||||
threading.Thread(target=receiver_loop, daemon=True).start()
|
||||
|
||||
def start_cache_service(self, device_ids, ipc_signal_suffix):
|
||||
def start_cache_service(self, device_ids, ipc_signal_suffix, create_cache_tensor):
|
||||
return self.resource_manager.cache_manager.launch_cache_manager(
|
||||
cache_config=self.cfg.cache_config,
|
||||
tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
|
||||
@@ -759,6 +795,7 @@ class EngineSevice:
|
||||
self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
|
||||
),
|
||||
pid_suffix=ipc_signal_suffix,
|
||||
create_cache_tensor=create_cache_tensor,
|
||||
)
|
||||
|
||||
def check_and_free_block_tables(self):
|
||||
@@ -773,8 +810,12 @@ class EngineSevice:
|
||||
self.exist_task_signal.clear()
|
||||
self.exist_swapped_task_signal.clear()
|
||||
self.worker_healthy_live_signal.clear()
|
||||
self.cache_ready_signal.clear()
|
||||
self.swap_space_ready_signal.clear()
|
||||
self.exist_prefill_task_signal.clear()
|
||||
self.model_weights_status_signal.clear()
|
||||
self.prefix_tree_status_signal.clear()
|
||||
self.kv_cache_status_signal.clear()
|
||||
if hasattr(self, "send_response_server") and self.send_response_server is not None:
|
||||
self.send_response_server.close()
|
||||
if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
|
||||
|
Reference in New Issue
Block a user