[feat] support prefix cache clearing when /clear_load_weight is called (#4091)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix code style

* [fix] wait for rank0 to update weight status
Author: 李泳桦
Date: 2025-09-16 11:11:20 +08:00 (committed by GitHub)
Parent: fbb4e0f8d1
Commit: 7ccbcc5a62
17 changed files with 624 additions and 181 deletions
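Before the file-by-file diff, a note on the coordination primitive this feature leans on: every participant maps the same named shared-memory array via IPCSignal, with the engine creating it and workers attaching. A minimal sketch, assuming a FastDeploy build containing this commit (the suffix value is illustrative):

import numpy as np
from fastdeploy.inter_communicator import IPCSignal, KVCacheStatus

# Engine side creates the signal (create=True); cache/worker processes attach
# with the same name and suffix but create=False.
kv_cache_status = IPCSignal(
    name="kv_cache_status",
    array=np.zeros([1], dtype=np.int32),
    dtype=np.int32,
    suffix=8002,  # illustrative: this diff keys signals by engine_worker_queue_port or engine_pid
    create=True,
)

kv_cache_status.value[0] = KVCacheStatus.CLEARING  # ask cache managers to drop the kv cache
# ...cache transfer managers free gpu/cpu caches, then write back:
# kv_cache_status.value[0] = KVCacheStatus.CLEARED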

View File

@@ -0,0 +1,71 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "cuda_multiprocess.h"
#if !defined(_WIN32)
#include <errno.h>
#include <string.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#endif
// Optional: delete/unlink only the named shared-memory object (does not rely on a previously saved addr/fd).
static inline int sharedMemoryUnlinkByName(const char* name) {
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
// Windows has no shm_unlink semantics; a named object disappears once its last handle closes.
// Best effort here: try to open the mapping and close it immediately to drop one reference.
HANDLE hMap = OpenFileMappingA(FILE_MAP_ALL_ACCESS, FALSE, name);
if (hMap) {
CloseHandle(hMap);
return 0;
}
// Already-gone also counts as success.
return 0;
#else
// POSIX: remove the name so it cannot be opened again; existing mappings remain alive until munmap.
if (shm_unlink(name) != 0) {
if (errno == ENOENT) return 0; // treat non-existence as success
return errno;
}
return 0;
#endif
}
void UnsetDataIpc(const paddle::Tensor& tmp_input,
const std::string& shm_name,
bool close_ipc,
bool unlink_shm) {
// 1) Close the consumer-imported IPC mapping (only when close_ipc=true and the pointer truly came from cudaIpcOpenMemHandle).
if (close_ipc) {
void* ptr = const_cast<void*>(tmp_input.data());
checkCudaErrors(cudaIpcCloseMemHandle(ptr));
}
// 2) Unlink the named shared-memory object (removes only the name; does not undo existing mappings).
if (unlink_shm) {
int rc = sharedMemoryUnlinkByName(shm_name.c_str());
if (rc != 0) {
PD_THROW("Unlink shared memory failed: name=%s, err=%d",
shm_name.c_str(), rc);
}
}
}
PD_BUILD_STATIC_OP(unset_data_ipc)
.Inputs({"tmp_input"})
.Attrs({"shm_name: std::string", "close_ipc: bool", "unlink_shm: bool"})
.SetKernelFn(PD_KERNEL(UnsetDataIpc));
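A hedged usage sketch of the new op from Python, assuming the compiled ops module as imported elsewhere in this diff; the tensor shape and shm name are illustrative:

import paddle
from fastdeploy.model_executor.ops.gpu import set_data_ipc, unset_data_ipc

cache = paddle.full(shape=[8, 4, 64, 128], fill_value=0, dtype="bfloat16")
shm_name = "key_caches_0_rank0.device0"  # naming scheme used elsewhere in this diff

set_data_ipc(cache, shm_name)  # export the tensor via a CUDA IPC handle stored in named shm

# Tear-down: close_ipc=True releases a mapping imported with cudaIpcOpenMemHandle;
# unlink_shm=True removes the shm name so no new process can attach.
unset_data_ipc(cache, shm_name, False, True)  # producer side: unlink the name only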

View File

@@ -208,6 +208,7 @@ if paddle.is_compiled_with_rocm():
"gpu_ops/rebuild_padding.cu",
"gpu_ops/step.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/unset_data_ipc.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/step_system_cache.cu",
"gpu_ops/get_output_ep.cc",
@@ -278,6 +279,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/beam_search_softmax.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/unset_data_ipc.cu",
"gpu_ops/read_data_ipc.cu",
"gpu_ops/enforce_generation.cu",
"gpu_ops/dequant_int8.cu",

View File

@@ -98,8 +98,8 @@ class CacheMessager:
cache_v = []
self.messager = {}
for layer_idx in range(self.num_layers):
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())

View File

@@ -16,21 +16,27 @@
import argparse
import concurrent.futures
import gc
import json
import queue
import threading
import time
import traceback
import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
cuda_host_free,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
unset_data_ipc,
)
from fastdeploy.utils import get_logger
@@ -93,6 +99,7 @@ def parse_args():
help="speculative config",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0)
parser.add_argument("--create_cache_tensor", action="store_true")
args = parser.parse_args()
return args
@@ -110,7 +117,6 @@ class CacheTransferManager:
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
self.gpu_cache_kvs = {}
self.cpu_cache_kvs = {}
self.gpu_cache_k_tensors = []
@@ -126,6 +132,7 @@ class CacheTransferManager:
self.n_ranks = args.mp_num
self.rank = rank
self.device = device
self.engine_pid = args.engine_pid
address = (args.pod_ip, args.cache_queue_port)
self.cache_task_queue = EngineCacheQueue(
@@ -136,70 +143,27 @@ class CacheTransferManager:
local_data_parallel_id=args.local_data_parallel_id,
)
self.num_cpu_blocks = args.num_cpu_blocks
cache_type = args.cache_dtype
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
set_data_ipc(
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
self.k_dst_ptrs.append(self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"])
self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
suffix=self.engine_pid,
create=False,
)
self.cache_ready_signal.value[self.rank] = 1
swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.swap_space_ready_signal = IPCSignal(
name="swap_space_ready_signal",
array=swap_space_ready_data,
dtype=np.int32,
suffix=self.engine_pid,
create=False,
)
self.num_cpu_blocks = args.num_cpu_blocks
self._init_cpu_cache(args)
self._init_gpu_cache(args)
paddle.set_device(f"gpu:{device}")
if args.enable_splitwise:
@@ -232,6 +196,72 @@ class CacheTransferManager:
create=False,
)
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
def _init_gpu_cache(self, args):
if not args.create_cache_tensor:
logger.info("Waiting for runners to create kv cache.")
while self.cache_ready_signal.value[self.rank] != 1:
time.sleep(1)
logger.info("OK! Stop waiting.")
logger.info("Initializing kv cache for all layers.")
paddle.set_device(f"gpu:{self.device}")
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
if args.create_cache_tensor:
logger.info(f"..creating kv cache for layer {i}: {cache_shape}")
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
set_data_ipc(key_cache, key_name)
set_data_ipc(val_cache, val_name)
else:
logger.info(f"..attaching kv cache for layer {i}: {cache_shape}")
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
key_cache = share_external_data(key_cache, key_name, cache_shape)
val_cache = share_external_data(val_cache, val_name, cache_shape)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[val_name] = val_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
if args.create_cache_tensor:
logger.info("✅ kv cache is ready!")
self.cache_ready_signal.value[self.rank] = 1
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
def _init_cpu_cache(self, args):
if args.num_cpu_blocks == 0:
logger.info("💡 no swap space (cpu cache) is specified.")
self.swap_space_ready_signal.value[self.rank] = 1
return
logger.info("Initializing swap space (cpu cache) for all layers.")
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
key_name = f"key_caches_{i}_rank{self.rank}"
val_name = f"value_caches_{i}_rank{self.rank}"
need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
logger.info(f"..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB")
self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
logger.info("✅ swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1
def _do_swap_to_cpu_task(
self,
swap_node_ids,
@@ -429,6 +459,67 @@ class CacheTransferManager:
transfer_task_id,
)
def clear_or_update_caches(self, args):
logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
kv_cache_status = np.zeros([1], dtype=np.int32)
kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=self.engine_pid,
create=False,
)
while True:
if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
try:
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
paddle.set_device("cpu")
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
gc.collect()
# reset swap_space_ready_signal
self.swap_space_ready_signal.value[self.rank] = 0
while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
paddle.set_device(f"gpu:{self.device}")
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
# reset cache_ready_signal
self.cache_ready_signal.value[self.rank] = 0
while np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
except Exception as e:
logger.error(f"Failed to clear caches: {e}")
elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
try:
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache(args)
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
time.sleep(0.1)
self._init_gpu_cache(args)
while np.sum(self.cache_ready_signal.value) != args.mp_num:
time.sleep(0.1)
kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
except Exception as e:
logger.error(f"Failed to restore caches: {e}")
time.sleep(0.1)
def main():
"""

View File

@@ -31,7 +31,7 @@ import numpy as np
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
@@ -71,6 +71,7 @@ class PrefixCacheManager:
else:
self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num
self.num_cpu_blocks = self.cache_config.num_cpu_blocks
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
if self.num_cpu_blocks > 0:
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
@@ -78,6 +79,7 @@ class PrefixCacheManager:
self.cpu_free_block_list = []
heapq.heapify(self.gpu_free_block_list)
heapq.heapify(self.cpu_free_block_list)
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
@@ -123,6 +125,7 @@ class PrefixCacheManager:
pod_ip,
engine_worker_queue_port,
pid_suffix,
create_cache_tensor,
):
"""
launch_cache_manager function used to initialize the cache manager.
@@ -133,7 +136,7 @@ class PrefixCacheManager:
name="cache_task_broadcast_signal",
array=broadcast_cache_task_flag_array,
dtype=np.int32,
suffix=pid_suffix,
suffix=engine_worker_queue_port,
create=True,
)
@@ -160,20 +163,41 @@ class PrefixCacheManager:
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
kv_num_head = max(1, kv_num_head)
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=pid_suffix,
create=True,
suffix=engine_worker_queue_port,
create=False,
)
swap_space_ready_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.swap_space_ready_signal = IPCSignal(
name="swap_space_ready_signal",
array=swap_space_ready_data,
dtype=np.int32,
suffix=engine_worker_queue_port,
create=False,
)
prefix_tree_status = np.zeros([1], dtype=np.int32)
self.prefix_tree_status_signal = IPCSignal(
name="prefix_tree_status",
array=prefix_tree_status,
dtype=np.int32,
suffix=engine_worker_queue_port,
create=False,
)
# Run command to launch cache transfer managers
logger.info(f"create_cache_tensor: {create_cache_tensor}")
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
@@ -196,23 +220,33 @@ class PrefixCacheManager:
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
# Wait for the kv cache to finish initializing
logger.info("Waiting for cache transfer manager ready...")
logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
# Start additional threads
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
self._enable_cpu_cache()
threading.Thread(target=self.recv_data_transfer_result).start()
if cache_config.enable_prefix_caching:
threading.Thread(target=self.clear_prefix_cache, daemon=True).start()
return cache_manager_processes
def update_cache_config(self, cache_config):
@@ -237,21 +271,6 @@ class PrefixCacheManager:
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_resource.set(1.0)
def _enable_cpu_cache(self):
"""
_enable_cpu_cache function used to enable cpu cache.
"""
# ipc_cache_queue_port = self.cache_config.cache_queue_port
# self.cache_task_queue = CacheQueueManager(
# rank=0,
# mp_num=tensor_parallel_size,
# port=ipc_cache_queue_port,
# )
# Start a listener thread that receives transfer-task results
self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
self.transfer_recv_thread.start()
def can_allocate_gpu_blocks(self, num_blocks: int):
"""
Check if num_blocks gpu blocks can be allocated.
@@ -1295,3 +1314,70 @@ class PrefixCacheManager:
except Exception as e:
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
raise e
def reset(self):
"""
Reset the RadixTree.
"""
if len(self.node_map) == 0:
return
logger.info("Resetting the RadixTree!")
# wait for swap tasks to finish
if self.gpu_free_task_future is not None:
self.gpu_free_task_future.result()
self.gpu_free_task_future = None
for event in list(self.task_swapping_event.values()):
event.wait()
self.task_swapping_event.clear()
# clear node map
self.node_map.clear()
self.req_leaf_map.clear()
self.leaf_req_map.clear()
self.unfilled_req_block_map.clear()
self.cache_info.clear()
# reset gpu cache data structure
self.gpu_lru_leaf_heap.clear()
self.gpu_lru_leaf_set.clear()
# reset cpu cache data structure
self.cpu_lru_leaf_heap.clear()
self.cpu_lru_leaf_set.clear()
# reset gpu/cpu free block list
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
if self.num_cpu_blocks > 0:
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
else:
self.cpu_free_block_list = []
heapq.heapify(self.gpu_free_block_list)
heapq.heapify(self.cpu_free_block_list)
# reset node/tree
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
# reset metrics
self.metrics.reset_metrics()
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
def clear_prefix_cache(self):
"""
If the model weights status is updating or clearing, reset prefix cache tree
"""
logger.info("Start a thread to clear prefix cache when model weights are cleared.")
prefix_tree_status_signal = self.prefix_tree_status_signal
while True:
if prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARING:
self.reset()
prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARED
logger.info("Prefix cache tree is cleared.")
if prefix_tree_status_signal.value[0] == PrefixTreeStatus.UPDATING:
prefix_tree_status_signal.value[0] = PrefixTreeStatus.NORMAL
logger.info("Prefix cache tree is updated.")
time.sleep(0.01)
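A small aside on the free-list idiom used in reset() above: building the list in descending order and then heapifying yields a min-heap, so allocation pops the lowest block id first. A quick self-contained demo:

import heapq

num_gpu_blocks = 4
gpu_free_block_list = list(range(num_gpu_blocks - 1, -1, -1))  # [3, 2, 1, 0]
heapq.heapify(gpu_free_block_list)                             # min-heap keyed by block id
print(heapq.heappop(gpu_free_block_list))                      # -> 0 (lowest id first)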

View File

@@ -387,6 +387,7 @@ class EngineArgs:
"""
Post-initialization processing to set default tokenizer if not provided.
"""
if not self.tokenizer:
self.tokenizer = self.model
if self.splitwise_role == "decode":
@@ -397,8 +398,8 @@ class EngineArgs:
self.enable_prefix_caching = False
if not current_platform.is_cuda():
self.enable_prefix_caching = False
if self.dynamic_load_weight:
self.enable_prefix_caching = False
# if self.dynamic_load_weight:
# self.enable_prefix_caching = False
if self.enable_logprob:
if self.speculative_config is not None:
raise NotImplementedError("Logprob does not support speculation_config.")

View File

@@ -174,6 +174,24 @@ class EngineSevice:
create=True,
)
cache_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
swap_space_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
self.swap_space_ready_signal = IPCSignal(
name="swap_space_ready_signal",
array=swap_space_ready_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal(
name="model_weights_status",
@@ -183,6 +201,24 @@ class EngineSevice:
create=True,
)
prefix_tree_status = np.zeros([1], dtype=np.int32)
self.prefix_tree_status_signal = IPCSignal(
name="prefix_tree_status",
array=prefix_tree_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
kv_cache_status = np.zeros([1], dtype=np.int32)
self.kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
def start_worker_queue_service(self, start_queue):
"""
start queue service for engine worker communication
@@ -749,7 +785,7 @@ class EngineSevice:
threading.Thread(target=receiver_loop, daemon=True).start()
def start_cache_service(self, device_ids, ipc_signal_suffix):
def start_cache_service(self, device_ids, ipc_signal_suffix, create_cache_tensor):
return self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
@@ -759,6 +795,7 @@ class EngineSevice:
self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
),
pid_suffix=ipc_signal_suffix,
create_cache_tensor=create_cache_tensor,
)
def check_and_free_block_tables(self):
@@ -773,8 +810,12 @@ class EngineSevice:
self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear()
self.worker_healthy_live_signal.clear()
self.cache_ready_signal.clear()
self.swap_space_ready_signal.clear()
self.exist_prefill_task_signal.clear()
self.model_weights_status_signal.clear()
self.prefix_tree_status_signal.clear()
self.kv_cache_status_signal.clear()
if hasattr(self, "send_response_server") and self.send_response_server is not None:
self.send_response_server.close()
if hasattr(self, "recv_request_server") and self.recv_request_server is not None:

View File

@@ -123,12 +123,12 @@ class LLMEngine:
llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
self.engine.start_zmq_service(api_server_pid)
if self.do_profile == 0 and (
self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
):
# If the block number is specified and the model is not deployed in mixed mode, start the cache manager first
if not self.do_profile and self.cfg.splitwise_role != "mixed":
device_ids = self.cfg.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix, True)
# Start workers
self.worker_proc = self._start_worker_service()
console_logger.info("Waiting worker processes ready...")
time.sleep(5)
@@ -154,11 +154,17 @@ class LLMEngine:
return False
time.sleep(1)
# If block number is not specified, let workers do profiling to determine the block number,
# and then start the cache manager
if self.do_profile:
self._stop_profile()
elif self.cfg.cache_config.enable_prefix_caching:
device_ids = self.cfg.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix, False)
# Launch components: scheduler, cache_manager, expert_service, etc.
self.launch_components()
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
if self.cfg.splitwise_role != "mixed":
self.launched_cache_manager_signal.value[0] = 1
# Worker launched
@@ -168,6 +174,23 @@ class LLMEngine:
return False
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
# Print the number of blocks & max running requests to the console
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
block_size = self.cfg.cache_config.block_size
num_gpu_blocks = self.cfg.cache_config.num_gpu_blocks_override or self.cfg.cache_config.total_block_num
num_cpu_blocks = self.cfg.cache_config.num_cpu_blocks
max_running_requests = min(
(num_gpu_blocks + num_cpu_blocks) * block_size // self.cfg.max_model_len, self.cfg.max_num_seqs
)
console_logger.info(
f"Detected {num_gpu_blocks} gpu blocks and {num_cpu_blocks} cpu blocks in cache (block size: {block_size})."
)
console_logger.info(
f"FastDeploy will be serving {max_running_requests} running requests "
f"if each sequence reaches its maximum length: {self.cfg.max_model_len}"
)
return True
def _get_generated_result(self):
@@ -580,7 +603,9 @@ class LLMEngine:
self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
device_ids = self.cfg.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
self.cache_manager_processes = self.engine.start_cache_service(
device_ids, self.ipc_signal_suffix, self.cfg.splitwise_role != "mixed"
)
def check_health(self, time_interval_threashold=30):
"""

View File

@@ -132,7 +132,6 @@ class ExpertService:
if hasattr(self, "cache_manager_processes"):
self.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.engine.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
try:

View File

@@ -16,6 +16,7 @@
import inspect
import os
import threading
import time
import traceback
import uuid
@@ -27,7 +28,13 @@ from fastdeploy.config import ModelConfig
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.inter_communicator import (
IPCSignal,
KVCacheStatus,
ModelWeightsStatus,
PrefixTreeStatus,
ZmqIpcClient,
)
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
@@ -55,6 +62,8 @@ class EngineClient:
enable_logprob=False,
workers=1,
tool_parser=None,
enable_prefix_caching=None,
splitwise_role=None,
):
import fastdeploy.model_executor.models # noqa: F401
@@ -76,6 +85,8 @@ class EngineClient:
self.reasoning_parser = reasoning_parser
self.data_processor = input_processor.create_processor()
self.max_model_len = max_model_len
self.enable_prefix_caching = enable_prefix_caching
self.enable_splitwise = splitwise_role != "mixed"
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if tensor_parallel_size <= max_chips_per_node:
@@ -101,10 +112,27 @@ class EngineClient:
suffix=port,
create=False,
)
prefix_tree_status = np.zeros([1], dtype=np.int32)
self.prefix_tree_status_signal = IPCSignal(
name="prefix_tree_status",
array=prefix_tree_status,
dtype=np.int32,
suffix=port,
create=False,
)
kv_cache_status = np.zeros([1], dtype=np.int32)
self.kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=port,
create=False,
)
self.connection_manager = DealerConnectionManager(
pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
)
self.connection_initialized = False
self.clear_update_lock = threading.Lock()
def create_zmq_client(self, model, mode):
"""
@@ -310,7 +338,7 @@ class EngineClient:
Check the health of the model server by checking whether all workers are alive.
"""
if self.model_weights_status_signal.value[0] == 0:
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
return True, ""
else:
return False, "No model weight enabled"
@@ -321,21 +349,42 @@ class EngineClient:
1 : worker receive the signal and start to update model weight
2 : worker update finish and notify client
"""
if self.model_weights_status_signal.value[0] == 0:
return True, ""
if self.model_weights_status_signal.value[0] == 1:
return False, "updating model weight already"
with self.clear_update_lock:
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
return False, "updating model weight already"
self.model_weights_status_signal.value[0] = 1
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
while self.model_weights_status_signal.value[0] != 0 and timeout != 0:
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
all_updated = False
while timeout >= 0 and not all_updated:
api_server_logger.info(
f"Updating model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_updated = weight_updated and cache_updated and prefix_updated
else:
all_updated = weight_updated and cache_updated
else:
all_updated = weight_updated
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Update model weight timeout"
time.sleep(1)
timeout -= 1
continue
if self.model_weights_status_signal.value[0] != 0:
return False, "Update model weight timeout"
time.sleep(1)
return True, ""
return True, ""
def clear_load_weight(self, timeout=300):
"""
@@ -343,19 +392,42 @@ class EngineClient:
-1 : worker receive the signal and start to clear model weight
-2 : worker clear finish and notify client
"""
if self.model_weights_status_signal.value[0] == -2:
return True, ""
if self.model_weights_status_signal.value[0] == -1:
return False, "clearing model weight already"
self.model_weights_status_signal.value[0] = -1
with self.clear_update_lock:
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
return False, "clearing model weight already"
api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
while self.model_weights_status_signal.value[0] != -2 and timeout != 0:
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
all_cleared = False
while timeout >= 0 and not all_cleared:
api_server_logger.info(
f"Clearing model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_cleared = weight_cleared and cache_cleared and prefix_cleared
else:
all_cleared = weight_cleared and cache_cleared
else:
all_cleared = weight_cleared
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Clear model weight timeout"
time.sleep(1)
timeout -= 1
continue
if self.model_weights_status_signal.value[0] != -2:
return False, "clear model weight timeout"
time.sleep(1)
return True, ""
return True, ""
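To make the polling contract above concrete, here is a self-contained simulation of the two sides; a plain Python list stands in for the shared-memory signal (the real code uses IPCSignal), and the status values mirror ipc_signal_const.py:

import threading
import time

NORMAL, CLEARING, CLEARED = 0, -1, -2
weights_status = [NORMAL]  # stands in for model_weights_status_signal.value (hypothetical)

def worker_side():
    # Mirrors DynamicWeightManager.check_model_weights_status: react to CLEARING,
    # then report CLEARED once parameters are dropped.
    while weights_status[0] != CLEARING:
        time.sleep(0.01)
    weights_status[0] = CLEARED

threading.Thread(target=worker_side, daemon=True).start()

# Client side, as in clear_load_weight: request the clear, then poll with a timeout.
weights_status[0] = CLEARING
timeout = 5
while timeout >= 0 and weights_status[0] != CLEARED:
    time.sleep(0.1)
    timeout -= 1
print("cleared" if weights_status[0] == CLEARED else "timeout")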

View File

@@ -170,6 +170,8 @@ async def lifespan(app: FastAPI):
enable_logprob=args.enable_logprob,
workers=args.workers,
tool_parser=args.tool_call_parser,
enable_prefix_caching=args.enable_prefix_caching,
splitwise_role=args.splitwise_role,
)
await engine_client.connection_manager.initialize()
app.state.dynamic_load_weight = args.dynamic_load_weight

View File

@@ -103,6 +103,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
# Port on which LLMEngine receives control commands, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
"FD_ENABLE_SWAP_SPACE_CLEARING": lambda: int(os.getenv("FD_ENABLE_SWAP_SPACE_CLEARING", "0")),
}
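A hedged sketch of enabling the new flag, assuming the envs module resolves these lambdas at attribute access (as its use in cache_transfer_manager.py in this diff suggests):

import os
os.environ["FD_ENABLE_SWAP_SPACE_CLEARING"] = "1"  # set before fastdeploy reads it

from fastdeploy import envs
assert envs.FD_ENABLE_SWAP_SPACE_CLEARING == 1  # cache processes will also free cpu swap space on clear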

View File

@@ -17,6 +17,12 @@
from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal, shared_memory_exists
from .ipc_signal_const import (
ExistTaskStatus,
KVCacheStatus,
ModelWeightsStatus,
PrefixTreeStatus,
)
from .zmq_client import ZmqIpcClient
from .zmq_server import ZmqIpcServer, ZmqTcpServer
@@ -28,4 +34,8 @@ __all__ = [
"ZmqTcpServer",
"ZmqIpcServer",
"shared_memory_exists",
"ExistTaskStatus",
"PrefixTreeStatus",
"ModelWeightsStatus",
"KVCacheStatus",
]

View File

@@ -0,0 +1,32 @@
from dataclasses import dataclass
@dataclass
class ModelWeightsStatus:
NORMAL = 0
UPDATING = 1
CLEARING = -1
CLEARED = -2
@dataclass
class PrefixTreeStatus:
NORMAL = 0
UPDATING = 1
CLEARING = -1
CLEARED = -2
@dataclass
class KVCacheStatus:
NORMAL = 0
UPDATING = 1
CLEARING = -1
CLEARED = -2
@dataclass
class ExistTaskStatus:
EMPTY = 0
EXIST = 1
REFUSE = 2
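Note these are plain class attributes used as integer enums (the @dataclass decorator is inert here, since no instance fields are declared). A functionally equivalent sketch with IntEnum, which keeps the int comparisons against the shared-memory arrays while adding named reprs for logs; this alternative is hypothetical, not part of the commit:

from enum import IntEnum

class KVCacheStatusEnum(IntEnum):  # hypothetical alternative
    NORMAL = 0
    UPDATING = 1
    CLEARING = -1
    CLEARED = -2

assert KVCacheStatusEnum.CLEARED == -2  # still comparable to raw np.int32 values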

View File

@@ -25,6 +25,7 @@ from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.inter_communicator import ModelWeightsStatus
class DynamicWeightManager:
@@ -59,6 +60,7 @@ class DynamicWeightManager:
def update_parameters(self, pid: int = 0) -> None:
"""Core method to update model parameters based on strategy."""
logger.info(f"start update paramaters: suffix={pid} rank={self.rank}")
start_time = time.perf_counter()
paddle.device.cuda.empty_cache()
@@ -106,7 +108,7 @@ class DynamicWeightManager:
def clear_parameters(self, pid: int = 0) -> None:
"""Clear all model parameters and free memory."""
logger.info("start clear paramaters")
logger.info(f"start clear paramaters: suffix={pid} rank={self.rank}")
paddle.device.cuda.empty_cache()
for param in self.model.state_dict().values():
param._clear_data()
@@ -119,7 +121,7 @@ class DynamicWeightManager:
paddle.distributed.barrier(self.parallel_config.ep_group)
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
paddle.distributed.shutdown_process_group()
self._update_shared_status(pid, -2)
self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
"""Update model parameters from given state dictionary."""
@@ -150,7 +152,7 @@ class DynamicWeightManager:
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.barrier(self.parallel_config.tp_group)
if not self.first_load:
self._update_shared_status(pid, 0)
self._update_shared_status(pid, ModelWeightsStatus.NORMAL)
self.first_load = False
def _get_gpu_id(self) -> int:
@@ -217,20 +219,20 @@ class DynamicWeightManager:
"""
check model weights status
"""
logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
is_stop = 0
while model_weights_status.value[0] != 0:
if model_weights_status.value[0] == 1:
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
logger.info("infer engine stopped! start to load new checkpoint...")
model_runner.update_parameters(pid)
elif model_weights_status.value[0] == -1:
elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
logger.info("infer engine stopped! start to clear checkpoint...")
model_runner.clear_parameters(pid)
while True:
if model_weights_status.value[0] == 0:
if model_weights_status.value[0] == ModelWeightsStatus.NORMAL:
logger.info("finished loading new checkpoint")
break
elif is_stop == 1 or (model_weights_status.value[0] == -2 and is_stop == 0):
elif is_stop == 1 or (model_weights_status.value[0] == ModelWeightsStatus.CLEARED and is_stop == 0):
if is_stop == 0:
logger.info("finished clearing checkpoint")
is_stop = 1

View File

@@ -59,6 +59,7 @@ else:
recover_decode_task,
set_value_by_flags_and_idx,
share_external_data,
set_data_ipc,
)
from fastdeploy.model_executor.pre_and_post_process import (
@@ -73,6 +74,7 @@ if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -978,7 +980,7 @@ class GPUModelRunner(ModelRunnerBase):
"""
Initialize kv cache
"""
cache_kvs = {}
# cache_kvs = {}
max_block_num = self.num_gpu_blocks
# Get kv cache dtype
@@ -999,34 +1001,50 @@ class GPUModelRunner(ModelRunnerBase):
)
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
cache_kvs_list = []
for i in range(self.model_config.num_hidden_layers):
key_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
cache_kvs_list.append(key_cache)
value_cache = paddle.empty(shape=[], dtype=cache_type)
value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
cache_kvs_list.append(value_cache)
cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
# Check if gpu runner needs to create kv cache
# 1. During profiling, it creates its own kv cache.
# 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled.
create_cache_tensor = profile or self.parallel_config.splitwise_role == "mixed"
if not create_cache_tensor:
logger.info("Waiting for cache managers to create kv cache..")
while cache_ready_signal.value[self.local_rank] != 1:
time.sleep(1)
logger.info("OK! Stop waiting.")
logger.info("Initializing kv cache for all layers.")
cache_kvs_list = []
for i in range(self.model_config.num_hidden_layers):
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
if create_cache_tensor:
logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
set_data_ipc(key_cache, key_cache_name)
set_data_ipc(val_cache, val_cache_name)
else:
logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
key_cache = paddle.empty(shape=[], dtype=cache_type)
val_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
cache_kvs_list.extend([key_cache, val_cache])
self.share_inputs["caches"] = cache_kvs_list
if not profile and create_cache_tensor:
logger.info("✅ kv cache is ready!")
cache_ready_signal.value[self.local_rank] = 1
self.share_inputs["caches"] = cache_kvs_list
else:
for i in range(self.model_config.num_hidden_layers):
cache_kvs[f"key_caches_{i}"] = paddle.full(
shape=kv_cache_shape,
fill_value=0,
dtype=cache_type,
)
cache_kvs[f"value_caches_{i}"] = paddle.full(
shape=kv_cache_shape,
fill_value=0,
dtype=cache_type,
)
self.share_inputs["caches"] = list(cache_kvs.values())
for value in cache_kvs.values():
del value
paddle.device.cuda.empty_cache()
def initialize_attn_backend(self) -> None:
@@ -1672,6 +1690,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()
paddle.device.cuda.empty_cache()
def clear_parameters(self, pid):
""" " Dynamic model loader use to clear parameters use for RL"""

View File

@@ -41,7 +41,7 @@ from fastdeploy.config import (
)
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.inter_communicator import ExistTaskStatus, IPCSignal, ModelWeightsStatus
from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, parse_quantization
@@ -175,7 +175,7 @@ class PaddleDisWorkerProc:
name="launched_expert_service_signal",
array=launched_expert_service_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
@@ -192,7 +192,7 @@ class PaddleDisWorkerProc:
name="worker_ready_signal",
array=workers_ready,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1
@@ -260,7 +260,7 @@ class PaddleDisWorkerProc:
self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
while True:
if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
if self.model_weights_status.value[0] != 0:
if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
@@ -281,7 +281,7 @@ class PaddleDisWorkerProc:
if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
self.task_queue.read_finish_flag.set(1)
else:
self.exist_task_signal.value[0] = 1
self.exist_task_signal.value[0] = ExistTaskStatus.EXIST
if self.parallel_config.tensor_parallel_size > 1:
# Synchronize the signal for other workers
@@ -292,7 +292,7 @@ class PaddleDisWorkerProc:
paddle.distributed.barrier(self.parallel_config.ep_group)
else:
paddle.distributed.barrier(self.parallel_config.tp_group)
if self.model_weights_signal[0] != 0:
if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
logger.info(f"Rank: {self.local_rank} has updated parameters.")
from fastdeploy.rl.dynamic_weight_manager import (
DynamicWeightManager,
@@ -304,16 +304,16 @@ class PaddleDisWorkerProc:
self.worker.model_runner,
self.parallel_config.engine_worker_queue_port,
)
self.model_weights_signal[0] = 0
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
self.insert_step = True
tasks, read_finish = self.task_queue.get_tasks()
if read_finish:
# Ensure that every worker get the task
self.exist_task_signal.value[0] = 0
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
self.task_queue.read_finish_flag.set(0)
req_dicts = []
@@ -389,7 +389,7 @@ class PaddleDisWorkerProc:
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
self.get_profile_block_num_signal.value[0] = num_blocks_local
@@ -397,18 +397,7 @@ class PaddleDisWorkerProc:
num_blocks_local = self.fd_config.parallel_config.total_block_num
logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
# wait for the engine to launch the cache_manager
if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
self.launched_cache_manager_signal = IPCSignal(
name="launched_cache_manager_signal",
array=launched_cache_manager_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False,
)
while np.any(self.launched_cache_manager_signal.value[0] <= 0):
time.sleep(0.01)
# 4. init kv_cache with accurate num_blocks
self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
@@ -443,7 +432,7 @@ class PaddleDisWorkerProc:
name="loaded_model_signal",
array=loaded_model_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
if self.ranks > 1: