mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[feat] support prefix cache clearing when /clear_load_weight is called (#4091)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [feat] support clearing prefix cache (cherry-picked from release/2.1)
* [fix] fix ipc suffix, use port instead
* [fix] fix prefix caching not enabled
* [fix] fix code style
* [fix] wait for rank0 to update weight status
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
@@ -27,7 +28,13 @@ from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
||||
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
|
||||
from fastdeploy.inter_communicator import (
|
||||
IPCSignal,
|
||||
KVCacheStatus,
|
||||
ModelWeightsStatus,
|
||||
PrefixTreeStatus,
|
||||
ZmqIpcClient,
|
||||
)
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
from fastdeploy.multimodal.registry import MultimodalRegistry
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -55,6 +62,8 @@ class EngineClient:
|
||||
enable_logprob=False,
|
||||
workers=1,
|
||||
tool_parser=None,
|
||||
enable_prefix_caching=None,
|
||||
splitwise_role=None,
|
||||
):
|
||||
import fastdeploy.model_executor.models # noqa: F401
|
||||
|
||||
@@ -76,6 +85,8 @@ class EngineClient:
|
||||
self.reasoning_parser = reasoning_parser
|
||||
self.data_processor = input_processor.create_processor()
|
||||
self.max_model_len = max_model_len
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
self.enable_splitwise = splitwise_role != "mixed"
|
||||
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
|
||||
if tensor_parallel_size <= max_chips_per_node:
|
||||
@@ -101,10 +112,27 @@ class EngineClient:
|
||||
suffix=port,
|
||||
create=False,
|
||||
)
|
||||
prefix_tree_status = np.zeros([1], dtype=np.int32)
|
||||
self.prefix_tree_status_signal = IPCSignal(
|
||||
name="prefix_tree_status",
|
||||
array=prefix_tree_status,
|
||||
dtype=np.int32,
|
||||
suffix=port,
|
||||
create=False,
|
||||
)
|
||||
kv_cache_status = np.zeros([1], dtype=np.int32)
|
||||
self.kv_cache_status_signal = IPCSignal(
|
||||
name="kv_cache_status",
|
||||
array=kv_cache_status,
|
||||
dtype=np.int32,
|
||||
suffix=port,
|
||||
create=False,
|
||||
)
|
||||
self.connection_manager = DealerConnectionManager(
|
||||
pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
|
||||
)
|
||||
self.connection_initialized = False
|
||||
self.clear_update_lock = threading.Lock()
|
||||
|
||||
def create_zmq_client(self, model, mode):
|
||||
"""
|
||||
@@ -310,7 +338,7 @@ class EngineClient:
|
||||
Check the health of the model server by checking whether all workers are alive.
|
||||
|
||||
"""
|
||||
if self.model_weights_status_signal.value[0] == 0:
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
|
||||
return True, ""
|
||||
else:
|
||||
return False, "No model weight enabled"
|
||||
@@ -321,21 +349,42 @@ class EngineClient:
|
||||
1 : worker receive the signal and start to update model weight
|
||||
2 : worker update finish and notify client
|
||||
"""
|
||||
if self.model_weights_status_signal.value[0] == 0:
|
||||
return True, ""
|
||||
if self.model_weights_status_signal.value[0] == 1:
|
||||
return False, "updating model weight already"
|
||||
with self.clear_update_lock:
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
|
||||
return True, ""
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
|
||||
return False, "updating model weight already"
|
||||
|
||||
self.model_weights_status_signal.value[0] = 1
|
||||
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
|
||||
while self.model_weights_status_signal.value[0] != 0 and timeout != 0:
|
||||
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
|
||||
if self.enable_prefix_caching or self.enable_splitwise:
|
||||
self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
|
||||
if self.enable_prefix_caching:
|
||||
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
|
||||
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
|
||||
all_updated = False
|
||||
while timeout >= 0 and not all_updated:
|
||||
api_server_logger.info(
|
||||
f"Updating model weights.. "
|
||||
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
|
||||
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
|
||||
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
|
||||
)
|
||||
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
|
||||
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
|
||||
prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
|
||||
if self.enable_prefix_caching or self.enable_splitwise:
|
||||
if self.enable_prefix_caching:
|
||||
all_updated = weight_updated and cache_updated and prefix_updated
|
||||
else:
|
||||
all_updated = weight_updated and cache_updated
|
||||
else:
|
||||
all_updated = weight_updated
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return False, "Update model weight timeout"
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
continue
|
||||
if self.model_weights_status_signal.value[0] != 0:
|
||||
return False, "Update model weight timeout"
|
||||
time.sleep(1)
|
||||
return True, ""
|
||||
return True, ""
|
||||
|
||||
def clear_load_weight(self, timeout=300):
|
||||
"""
|
||||
@@ -343,19 +392,42 @@ class EngineClient:
|
||||
-1 : worker receive the signal and start to clear model weight
|
||||
-2 : worker clear finish and notify client
|
||||
"""
|
||||
if self.model_weights_status_signal.value[0] == -2:
|
||||
return True, ""
|
||||
if self.model_weights_status_signal.value[0] == -1:
|
||||
return False, "clearing model weight already"
|
||||
|
||||
self.model_weights_status_signal.value[0] = -1
|
||||
with self.clear_update_lock:
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
|
||||
return True, ""
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
|
||||
return False, "clearing model weight already"
|
||||
|
||||
api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
|
||||
while self.model_weights_status_signal.value[0] != -2 and timeout != 0:
|
||||
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
|
||||
if self.enable_prefix_caching or self.enable_splitwise:
|
||||
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
|
||||
if self.enable_prefix_caching:
|
||||
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
|
||||
|
||||
api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
|
||||
all_cleared = False
|
||||
while timeout >= 0 and not all_cleared:
|
||||
api_server_logger.info(
|
||||
f"Clearing model weights.. "
|
||||
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
|
||||
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
|
||||
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
|
||||
)
|
||||
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
|
||||
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
|
||||
prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED
|
||||
if self.enable_prefix_caching or self.enable_splitwise:
|
||||
if self.enable_prefix_caching:
|
||||
all_cleared = weight_cleared and cache_cleared and prefix_cleared
|
||||
else:
|
||||
all_cleared = weight_cleared and cache_cleared
|
||||
else:
|
||||
all_cleared = weight_cleared
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
|
||||
if timeout < 0:
|
||||
return False, "Clear model weight timeout"
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
continue
|
||||
if self.model_weights_status_signal.value[0] != -2:
|
||||
return False, "clear model weight timeout"
|
||||
time.sleep(1)
|
||||
return True, ""
|
||||
return True, ""
|
||||
|
Reference in New Issue
Block a user