[feat] support prefix cache clearing when /clear_load_weight is called (#4091)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix code style

* [fix] wait for rank0 to update weight status
This commit is contained in:
李泳桦
2025-09-16 11:11:20 +08:00
committed by GitHub
parent fbb4e0f8d1
commit 7ccbcc5a62
17 changed files with 624 additions and 181 deletions

View File

@@ -174,6 +174,24 @@ class EngineSevice:
create=True,
)
cache_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
swap_space_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
self.swap_space_ready_signal = IPCSignal(
name="swap_space_ready_signal",
array=swap_space_ready_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal(
name="model_weights_status",
@@ -183,6 +201,24 @@ class EngineSevice:
create=True,
)
prefix_tree_status = np.zeros([1], dtype=np.int32)
self.prefix_tree_status_signal = IPCSignal(
name="prefix_tree_status",
array=prefix_tree_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
kv_cache_status = np.zeros([1], dtype=np.int32)
self.kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
def start_worker_queue_service(self, start_queue):
"""
start queue service for engine worker communication
@@ -749,7 +785,7 @@ class EngineSevice:
threading.Thread(target=receiver_loop, daemon=True).start()
def start_cache_service(self, device_ids, ipc_signal_suffix):
def start_cache_service(self, device_ids, ipc_signal_suffix, create_cache_tensor):
return self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
@@ -759,6 +795,7 @@ class EngineSevice:
self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
),
pid_suffix=ipc_signal_suffix,
create_cache_tensor=create_cache_tensor,
)
def check_and_free_block_tables(self):
@@ -773,8 +810,12 @@ class EngineSevice:
self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear()
self.worker_healthy_live_signal.clear()
self.cache_ready_signal.clear()
self.swap_space_ready_signal.clear()
self.exist_prefill_task_signal.clear()
self.model_weights_status_signal.clear()
self.prefix_tree_status_signal.clear()
self.kv_cache_status_signal.clear()
if hasattr(self, "send_response_server") and self.send_response_server is not None:
self.send_response_server.close()
if hasattr(self, "recv_request_server") and self.recv_request_server is not None: