[feat] support prefix cache clearing when /clear_load_weight is called (#4091)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix code style

* [fix] wait for rank0 to update weight status
李泳桦 authored on 2025-09-16 11:11:20 +08:00 · committed by GitHub
parent fbb4e0f8d1 · commit 7ccbcc5a62
17 changed files with 624 additions and 181 deletions
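For context before the diff: this change is exercised through the weight-management route named in the title. A minimal sketch of a client call is below; the server address and the HTTP method are assumptions for illustration, not details taken from this commit.

# Hypothetical client call; the base URL and the use of GET are assumed,
# not specified in this commit.
import requests

BASE_URL = "http://localhost:8000"  # assumed FastDeploy API server address

# Ask the server to clear model weights (and, with this PR, also the prefix
# tree and KV cache when prefix caching or splitwise deployment is enabled).
resp = requests.get(f"{BASE_URL}/clear_load_weight", timeout=310)
print(resp.status_code, resp.text)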


@@ -16,6 +16,7 @@
 import inspect
 import os
 import threading
+import time
 import traceback
 import uuid
@@ -27,7 +28,13 @@ from fastdeploy.config import ModelConfig
 from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
 from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
 from fastdeploy.input.preprocess import InputPreprocessor
-from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
+from fastdeploy.inter_communicator import (
+    IPCSignal,
+    KVCacheStatus,
+    ModelWeightsStatus,
+    PrefixTreeStatus,
+    ZmqIpcClient,
+)
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
@@ -55,6 +62,8 @@ class EngineClient:
         enable_logprob=False,
         workers=1,
         tool_parser=None,
+        enable_prefix_caching=None,
+        splitwise_role=None,
     ):
         import fastdeploy.model_executor.models # noqa: F401
@@ -76,6 +85,8 @@ class EngineClient:
         self.reasoning_parser = reasoning_parser
         self.data_processor = input_processor.create_processor()
         self.max_model_len = max_model_len
+        self.enable_prefix_caching = enable_prefix_caching
+        self.enable_splitwise = splitwise_role != "mixed"
         max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         if tensor_parallel_size <= max_chips_per_node:
@@ -101,10 +112,27 @@ class EngineClient:
             suffix=port,
             create=False,
         )
+        prefix_tree_status = np.zeros([1], dtype=np.int32)
+        self.prefix_tree_status_signal = IPCSignal(
+            name="prefix_tree_status",
+            array=prefix_tree_status,
+            dtype=np.int32,
+            suffix=port,
+            create=False,
+        )
+        kv_cache_status = np.zeros([1], dtype=np.int32)
+        self.kv_cache_status_signal = IPCSignal(
+            name="kv_cache_status",
+            array=kv_cache_status,
+            dtype=np.int32,
+            suffix=port,
+            create=False,
+        )
         self.connection_manager = DealerConnectionManager(
             pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
         )
         self.connection_initialized = False
+        self.clear_update_lock = threading.Lock()

     def create_zmq_client(self, model, mode):
         """
@@ -310,7 +338,7 @@ class EngineClient:
         Check the health of the model server by checking whether all workers are alive.
         """
-        if self.model_weights_status_signal.value[0] == 0:
+        if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
             return True, ""
         else:
             return False, "No model weight enabled"
@@ -321,21 +349,42 @@ class EngineClient:
         1 : worker receive the signal and start to update model weight
         2 : worker update finish and notify client
         """
-        if self.model_weights_status_signal.value[0] == 0:
-            return True, ""
-        if self.model_weights_status_signal.value[0] == 1:
-            return False, "updating model weight already"
+        with self.clear_update_lock:
+            if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
+                return True, ""
+            if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
+                return False, "updating model weight already"
-        self.model_weights_status_signal.value[0] = 1
-        api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
-        while self.model_weights_status_signal.value[0] != 0 and timeout != 0:
+            self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
+            if self.enable_prefix_caching or self.enable_splitwise:
+                self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
+                if self.enable_prefix_caching:
+                    self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
+            api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
+            all_updated = False
+            while timeout >= 0 and not all_updated:
+                api_server_logger.info(
+                    f"Updating model weights.. "
+                    f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
+                    f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
+                    f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
+                )
+                weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
+                cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
+                prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
+                if self.enable_prefix_caching or self.enable_splitwise:
+                    if self.enable_prefix_caching:
+                        all_updated = weight_updated and cache_updated and prefix_updated
+                    else:
+                        all_updated = weight_updated and cache_updated
+                else:
+                    all_updated = weight_updated
+                time.sleep(1)
+                timeout -= 1
+            if timeout < 0:
+                return False, "Update model weight timeout"
-            time.sleep(1)
-            timeout -= 1
-            continue
-        if self.model_weights_status_signal.value[0] != 0:
-            return False, "Update model weight timeout"
-        time.sleep(1)
-        return True, ""
+            return True, ""

     def clear_load_weight(self, timeout=300):
         """
@@ -343,19 +392,42 @@ class EngineClient:
         -1 : worker receive the signal and start to clear model weight
         -2 : worker clear finish and notify client
         """
-        if self.model_weights_status_signal.value[0] == -2:
-            return True, ""
-        if self.model_weights_status_signal.value[0] == -1:
-            return False, "clearing model weight already"
-        self.model_weights_status_signal.value[0] = -1
+        with self.clear_update_lock:
+            if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
+                return True, ""
+            if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
+                return False, "clearing model weight already"
-        api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
-        while self.model_weights_status_signal.value[0] != -2 and timeout != 0:
+            self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
+            if self.enable_prefix_caching or self.enable_splitwise:
+                self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
+                if self.enable_prefix_caching:
+                    self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
+            api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
+            all_cleared = False
+            while timeout >= 0 and not all_cleared:
+                api_server_logger.info(
+                    f"Clearing model weights.. "
+                    f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
+                    f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
+                    f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
+                )
+                weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
+                cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
+                prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED
+                if self.enable_prefix_caching or self.enable_splitwise:
+                    if self.enable_prefix_caching:
+                        all_cleared = weight_cleared and cache_cleared and prefix_cleared
+                    else:
+                        all_cleared = weight_cleared and cache_cleared
+                else:
+                    all_cleared = weight_cleared
+                time.sleep(1)
+                timeout -= 1
+            if timeout < 0:
+                return False, "Clear model weight timeout"
-            time.sleep(1)
-            timeout -= 1
-            continue
-        if self.model_weights_status_signal.value[0] != -2:
-            return False, "clear model weight timeout"
-        time.sleep(1)
-        return True, ""
+            return True, ""