[feat] support prefix cache clearing when /clear_load_weight is called (#4008)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix key/value_cache_scales indent

* [fix] fix ep group all-reduce

* [fix] fix clear/update lock not working when workers > 1

* [chore] add preemption triggered info log

* [fix] fix code style

* [fix] fix max_num_seqs config

* [fix] do not force enable_prefix_caching=False in dynamic loading

* [fix] fix ci

* Revert "[fix] fix ci"

This reverts commit 0bc6d55cc8.

* [fix] initialize available_gpu_block_num with max_gpu_block_num

* [fix] fix config splitwise_role

* [fix] fix clearing caches synchronization and add more logs

* [chore] print cache_ready_signal in log

* [fix] fix scheduler_config.splitwise_role

* [fix] fix cache_messager cache_ready_signal create=True

* [fix] stop cache messager from launching in mixed deployment
Author: 李泳桦, 2025-09-28 19:42:53 +08:00 (committed by GitHub)
parent 59313ed7f9, commit 6265f4385f
20 changed files with 697 additions and 213 deletions


@@ -0,0 +1,71 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "cuda_multiprocess.h"
#if !defined(_WIN32)
#include <errno.h>
#include <string.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#endif
// Optional: remove/unlink only the named shared-memory object (does not depend on a previously saved addr/fd).
static inline int sharedMemoryUnlinkByName(const char* name) {
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
// Windows has no shm_unlink semantics: a named object disappears once its last handle is closed.
// Best effort here: try to open the mapping and close it immediately to drop one reference.
HANDLE hMap = OpenFileMappingA(FILE_MAP_ALL_ACCESS, FALSE, name);
if (hMap) {
CloseHandle(hMap);
return 0;
}
// Treat "already gone" as success as well.
return 0;
#else
// POSIX: remove the name so it cannot be opened again; existing mappings stay alive until munmap.
if (shm_unlink(name) != 0) {
if (errno == ENOENT) return 0; // treat "does not exist" as success
return errno;
}
return 0;
#endif
}
void UnsetDataIpc(const paddle::Tensor& tmp_input,
const std::string& shm_name,
bool close_ipc,
bool unlink_shm) {
// 1) Close the consumer-side IPC mapping (only when close_ipc=true and the pointer really came from cudaIpcOpenMemHandle).
if (close_ipc) {
void* ptr = const_cast<void*>(tmp_input.data());
checkCudaErrors(cudaIpcCloseMemHandle(ptr));
}
// 2) Unlink the named shared-memory object (handles the name only; existing mappings are not released).
if (unlink_shm) {
int rc = sharedMemoryUnlinkByName(shm_name.c_str());
if (rc != 0) {
PD_THROW("Unlink shared memory failed: name=%s, err=%d",
shm_name.c_str(), rc);
}
}
}
PD_BUILD_STATIC_OP(unset_data_ipc)
.Inputs({"tmp_input"})
.Attrs({"shm_name: std::string", "close_ipc: bool", "unlink_shm: bool"})
.SetKernelFn(PD_KERNEL(UnsetDataIpc));
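
A brief usage sketch of the new op (assumptions: it is exported next to set_data_ipc as fastdeploy.model_executor.ops.gpu.unset_data_ipc, and the shm name/shape below are illustrative):

import paddle
from fastdeploy.model_executor.ops.gpu import set_data_ipc, unset_data_ipc

paddle.set_device("gpu:0")
shm_name = "key_caches_0_rank0.device0"  # naming convention used elsewhere in this PR
key_cache = paddle.full(shape=[1024, 8, 64, 128], fill_value=0, dtype="bfloat16")
set_data_ipc(key_cache, shm_name)        # export the tensor under a named shm segment

# ... other processes attach via share_external_data(paddle.empty([], "bfloat16"), shm_name, shape) ...

# Tear down when weights/caches are cleared, mirroring clear_or_update_caches below:
# close_ipc=True closes a cudaIpcOpenMemHandle mapping if this pointer came from one,
# unlink_shm=False keeps the name so it can be re-created on the next update.
unset_data_ipc(key_cache, shm_name, True, False)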


@@ -208,6 +208,7 @@ if paddle.is_compiled_with_rocm():
"gpu_ops/rebuild_padding.cu", "gpu_ops/rebuild_padding.cu",
"gpu_ops/step.cu", "gpu_ops/step.cu",
"gpu_ops/set_data_ipc.cu", "gpu_ops/set_data_ipc.cu",
"gpu_ops/unset_data_ipc.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu", "gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/step_system_cache.cu", "gpu_ops/step_system_cache.cu",
"gpu_ops/get_output_ep.cc", "gpu_ops/get_output_ep.cc",
@@ -278,6 +279,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/beam_search_softmax.cu", "gpu_ops/beam_search_softmax.cu",
"gpu_ops/rebuild_padding.cu", "gpu_ops/rebuild_padding.cu",
"gpu_ops/set_data_ipc.cu", "gpu_ops/set_data_ipc.cu",
"gpu_ops/unset_data_ipc.cu",
"gpu_ops/read_data_ipc.cu", "gpu_ops/read_data_ipc.cu",
"gpu_ops/enforce_generation.cu", "gpu_ops/enforce_generation.cu",
"gpu_ops/dequant_int8.cu", "gpu_ops/dequant_int8.cu",


@@ -152,8 +152,8 @@ class CacheMessager:
cache_v = [] cache_v = []
self.messager = {} self.messager = {}
for layer_idx in range(self.num_layers): for layer_idx in range(self.num_layers):
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
cache_k.append(key_cache) cache_k.append(key_cache)
cache_v.append(val_cache) cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr()) cache_k_ptr_list.append(key_cache.data_ptr())


@@ -16,21 +16,27 @@
import argparse import argparse
import concurrent.futures import concurrent.futures
import gc
import json import json
import queue import queue
import threading
import time import time
import traceback import traceback
import numpy as np import numpy as np
import paddle import paddle
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import SpeculativeConfig from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.model_executor.ops.gpu import ( from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc, cuda_host_alloc,
cuda_host_free,
set_data_ipc,
share_external_data, share_external_data,
swap_cache_all_layers, swap_cache_all_layers,
unset_data_ipc,
) )
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
@@ -93,6 +99,7 @@ def parse_args():
help="speculative config", help="speculative config",
) )
parser.add_argument("--local_data_parallel_id", type=int, default=0) parser.add_argument("--local_data_parallel_id", type=int, default=0)
parser.add_argument("--create_cache_tensor", action="store_true")
args = parser.parse_args() args = parser.parse_args()
return args return args
@@ -110,7 +117,6 @@ class CacheTransferManager:
device = args.device_id device = args.device_id
rank = args.rank rank = args.rank
paddle.set_device(f"gpu:{device}")
self.gpu_cache_kvs = {} self.gpu_cache_kvs = {}
self.cpu_cache_kvs = {} self.cpu_cache_kvs = {}
self.gpu_cache_k_tensors = [] self.gpu_cache_k_tensors = []
@@ -126,6 +132,7 @@ class CacheTransferManager:
self.n_ranks = args.mp_num self.n_ranks = args.mp_num
self.rank = rank self.rank = rank
self.device = device self.device = device
self.engine_pid = args.engine_pid
address = (args.pod_ip, args.cache_queue_port) address = (args.pod_ip, args.cache_queue_port)
self.cache_task_queue = EngineCacheQueue( self.cache_task_queue = EngineCacheQueue(
@@ -136,57 +143,27 @@ class CacheTransferManager:
local_data_parallel_id=args.local_data_parallel_id, local_data_parallel_id=args.local_data_parallel_id,
) )
self.num_cpu_blocks = args.num_cpu_blocks
cache_type = args.cache_dtype
cache_shape = [
args.num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
]
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape[0] = num_gpu_blocks
key_name = f"key_caches_{i}_rank{rank}.device{device}"
value_name = f"value_caches_{i}_rank{rank}.device{device}"
key_cache = paddle.empty(shape=[], dtype=cache_type)
value_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data(key_cache, key_name, cache_shape)
value_cache = share_external_data(value_cache, value_name, cache_shape)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[value_name] = value_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
self.k_dst_ptrs.append(self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"])
self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32) cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_ready_signal = IPCSignal( self.cache_ready_signal = IPCSignal(
name="cache_ready_signal", name="cache_ready_signal",
array=cache_ready_signal_data, array=cache_ready_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=args.engine_pid, suffix=self.engine_pid,
create=False, create=False,
) )
self.cache_ready_signal.value[self.rank] = 1 swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.swap_space_ready_signal = IPCSignal(
name="swap_space_ready_signal",
array=swap_space_ready_data,
dtype=np.int32,
suffix=self.engine_pid,
create=False,
)
self.num_cpu_blocks = args.num_cpu_blocks
self._init_cpu_cache(args)
self._init_gpu_cache(args)
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal( self.cache_task_broadcast_signal = IPCSignal(
@@ -197,6 +174,76 @@ class CacheTransferManager:
create=False, create=False,
) )
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
def _init_gpu_cache(self, args):
if not args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners to create kv cache.")
while self.cache_ready_signal.value[self.rank] != 1:
time.sleep(0.1)
logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
paddle.set_device(f"gpu:{self.device}")
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
if args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {cache_shape}")
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
set_data_ipc(key_cache, key_name)
set_data_ipc(val_cache, val_name)
else:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
key_cache = share_external_data(key_cache, key_name, cache_shape)
val_cache = share_external_data(val_cache, val_name, cache_shape)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[val_name] = val_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
if args.create_cache_tensor:
logger.info("[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
self.cache_ready_signal.value[self.rank] = 1
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(
f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
)
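
The create-vs-attach split above is the heart of this change. Schematically, the handshake between the process that owns the kv cache and the one that attaches to it looks like the sketch below (single rank, illustrative shape and suffix; both sides must use the same signal suffix and tensor names):

import time
import numpy as np
import paddle
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.ops.gpu import set_data_ipc, share_external_data

rank, device, suffix = 0, 0, 8021
cache_shape = [1024, 8, 64, 128]
key_name = f"key_caches_0_rank{rank}.device{device}"
cache_ready_signal = IPCSignal(name="cache_ready_signal", array=np.zeros([1], dtype=np.int32),
                               dtype=np.int32, suffix=suffix, create=False)

# Creator side (create_cache_tensor=True, e.g. the gpu runner in mixed deployment):
paddle.set_device(f"gpu:{device}")
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype="bfloat16")
set_data_ipc(key_cache, key_name)           # publish the allocation
cache_ready_signal.value[rank] = 1          # announce it is ready

# Attacher side (create_cache_tensor not set, e.g. this cache transfer manager):
while cache_ready_signal.value[rank] != 1:
    time.sleep(0.1)
attached = share_external_data(paddle.empty(shape=[], dtype="bfloat16"), key_name, cache_shape)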
def _init_cpu_cache(self, args):
if args.num_cpu_blocks == 0:
logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
self.swap_space_ready_signal.value[self.rank] = 1
return
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.")
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
key_name = f"key_caches_{i}_rank{self.rank}"
val_name = f"value_caches_{i}_rank{self.rank}"
need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB"
)
self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1
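
As a rough sense of scale for the swap space allocated above (a back-of-the-envelope sketch; bytes_per_layer_per_block is supplied by the engine, and the formula below is only an assumption for illustration):

num_layers, num_extra_layers = 32, 0
num_cpu_blocks = 1024
bytes_per_layer_per_block = 64 * 8 * 128 * 2   # block_size * kv_heads_per_rank * head_dim * 2-byte dtype, per K or V
total_bytes = 2 * (num_layers + num_extra_layers) * num_cpu_blocks * bytes_per_layer_per_block
print(f"pinned host memory per rank: {total_bytes / 1024 ** 3:.2f} GB")   # 8.00 GB with these numbers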
def _do_swap_to_cpu_task( def _do_swap_to_cpu_task(
self, self,
swap_node_ids, swap_node_ids,
@@ -394,6 +441,92 @@ class CacheTransferManager:
transfer_task_id, transfer_task_id,
) )
def clear_or_update_caches(self, args):
logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
kv_cache_status = np.zeros([1], dtype=np.int32)
kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=self.engine_pid,
create=False,
)
while True:
if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
try:
logger.info(
f"[rank {self.rank}/{self.n_ranks}] Start clearing caches {self.cache_ready_signal.value}"
)
# clear cpu caches
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
paddle.set_device("cpu")
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
gc.collect()
# reset swap_space_ready_signal
self.swap_space_ready_signal.value[self.rank] = 0
while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
# clear gpu caches
paddle.set_device(f"gpu:{self.device}")
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
# reset cache_ready_signal
self.cache_ready_signal.value[self.rank] = 0
logger.info(
f"[rank {self.rank}/{self.n_ranks}] Finish clearing caches {self.cache_ready_signal.value}"
)
# wait for all ranks caches to be cleared
if np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
# reset kv_cache_status_signal
kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
logger.info("All ranks finish clearing caches")
except Exception as e:
logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to clear caches: {e}")
elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
try:
logger.info(
f"[rank {self.rank}/{self.n_ranks}] Start restoring caches {self.cache_ready_signal.value}"
)
# restore cpu cache
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache(args)
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
time.sleep(0.1)
# restore gpu cache and set cache_ready_signal
self._init_gpu_cache(args)
logger.info(
f"[rank {self.rank}/{self.n_ranks}] Finish restoring caches {self.cache_ready_signal.value}"
)
# wait for all ranks caches to be ready
while np.sum(self.cache_ready_signal.value) != args.mp_num:
time.sleep(0.1)
# set kv_cache_status_signal
logger.info("All ranks finish restoring caches")
kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
except Exception as e:
logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to restore caches: {e}")
time.sleep(0.1)
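
The thread above only reacts to the kv_cache_status word; the requesting side (the API server in this PR) drives it. A minimal sketch of that controller half, reusing the same IPCSignal/KVCacheStatus names (the suffix must match the engine's signal suffix):

import time
import numpy as np
from fastdeploy.inter_communicator import IPCSignal, KVCacheStatus

def request_kv_cache_clear(suffix, timeout_s=60):
    sig = IPCSignal(name="kv_cache_status", array=np.zeros([1], dtype=np.int32),
                    dtype=np.int32, suffix=suffix, create=False)
    sig.value[0] = KVCacheStatus.CLEARING            # ask every rank to drop its kv cache
    deadline = time.time() + timeout_s
    while sig.value[0] != KVCacheStatus.CLEARED:     # cache transfer managers ack with CLEARED
        if time.time() > deadline:
            return False
        time.sleep(0.1)
    return True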
def main(): def main():
""" """


@@ -31,7 +31,7 @@ import numpy as np
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
@@ -71,6 +71,7 @@ class PrefixCacheManager:
else: else:
self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num
self.num_cpu_blocks = self.cache_config.num_cpu_blocks self.num_cpu_blocks = self.cache_config.num_cpu_blocks
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1)) self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
if self.num_cpu_blocks > 0: if self.num_cpu_blocks > 0:
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1)) self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
@@ -78,6 +79,7 @@ class PrefixCacheManager:
self.cpu_free_block_list = [] self.cpu_free_block_list = []
heapq.heapify(self.gpu_free_block_list) heapq.heapify(self.gpu_free_block_list)
heapq.heapify(self.cpu_free_block_list) heapq.heapify(self.cpu_free_block_list)
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks)) self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None) self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
@@ -111,6 +113,10 @@ class PrefixCacheManager:
+ f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}" + f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
) )
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_resource.set(1.0)
@property @property
def available_gpu_resource(self): def available_gpu_resource(self):
return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0 return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
@@ -123,6 +129,7 @@ class PrefixCacheManager:
pod_ip, pod_ip,
engine_worker_queue_port, engine_worker_queue_port,
pid_suffix, pid_suffix,
create_cache_tensor,
): ):
""" """
launch_cache_manager function used to initialize the cache manager. launch_cache_manager function used to initialize the cache manager.
@@ -133,7 +140,7 @@ class PrefixCacheManager:
name="cache_task_broadcast_signal", name="cache_task_broadcast_signal",
array=broadcast_cache_task_flag_array, array=broadcast_cache_task_flag_array,
dtype=np.int32, dtype=np.int32,
suffix=pid_suffix, suffix=engine_worker_queue_port,
create=True, create=True,
) )
@@ -151,6 +158,7 @@ class PrefixCacheManager:
py_path = os.path.join(current_dir_path, filename) py_path = os.path.join(current_dir_path, filename)
cache_messager_processes = [] cache_messager_processes = []
if self.enable_splitwise:
cache_messager_processes = self.launch_cache_messager( cache_messager_processes = self.launch_cache_messager(
cache_config, cache_config,
tensor_parallel_size, tensor_parallel_size,
@@ -173,20 +181,41 @@ class PrefixCacheManager:
else: else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
kv_num_head = max(1, kv_num_head) kv_num_head = max(1, kv_num_head)
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32) cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal( self.cache_ready_signal = IPCSignal(
name="cache_ready_signal", name="cache_ready_signal",
array=cache_ready_signal_data, array=cache_ready_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=pid_suffix, suffix=engine_worker_queue_port,
create=True, create=False,
) )
swap_space_ready_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.swap_space_ready_signal = IPCSignal(
name="swap_space_ready_signal",
array=swap_space_ready_data,
dtype=np.int32,
suffix=engine_worker_queue_port,
create=False,
)
prefix_tree_status = np.zeros([1], dtype=np.int32)
self.prefix_tree_status_signal = IPCSignal(
name="prefix_tree_status",
array=prefix_tree_status,
dtype=np.int32,
suffix=engine_worker_queue_port,
create=False,
)
# Run command to launch cache transfer managers
logger.info(f"create_cache_tensor: {create_cache_tensor}")
log_dir = envs.FD_LOG_DIR log_dir = envs.FD_LOG_DIR
cache_manager_processes = [] cache_manager_processes = []
for i in range(tensor_parallel_size): for i in range(tensor_parallel_size):
launch_cmd = ( launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7" "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}" + f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}" + f" --device_id {int(device_ids[i])}"
+ f" --rank {i}" + f" --rank {i}"
@@ -209,23 +238,33 @@ class PrefixCacheManager:
+ f" --local_data_parallel_id {self.local_data_parallel_id}" + f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1" + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
) )
logger.info(f"Launch cache transfer manager, command:{launch_cmd}") logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
# Wait for the cache to finish initializing
logger.info("Waiting for cache transfer manager ready...") logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size: while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1) time.sleep(1)
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_manager_processes[-1].poll() exit_code = cache_manager_processes[-1].poll()
if exit_code is None: if exit_code is None:
logger.info("Launch cache transfer manager successful") logger.info("Launch cache transfer manager successful")
else: else:
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information") logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
# Start additional threads
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0: if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.") logger.info("Enable hierarchical cache.")
self._enable_cpu_cache() threading.Thread(target=self.recv_data_transfer_result).start()
if cache_config.enable_prefix_caching:
threading.Thread(target=self.clear_prefix_cache, daemon=True).start()
all_cache_processes = cache_messager_processes + cache_manager_processes all_cache_processes = cache_messager_processes + cache_manager_processes
return all_cache_processes return all_cache_processes
@@ -253,7 +292,7 @@ class PrefixCacheManager:
array=cache_ready_signal_data, array=cache_ready_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=pid_suffix, suffix=pid_suffix,
create=True, create=False,
) )
py_path = os.path.join(current_dir_path, filename) py_path = os.path.join(current_dir_path, filename)
@@ -286,6 +325,7 @@ class PrefixCacheManager:
) )
logger.info(f"Launch cache messager, command:{launch_cmd}") logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
logger.info("Waiting for cache ready...") logger.info("Waiting for cache ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size: while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1) time.sleep(1)
@@ -317,23 +357,9 @@ class PrefixCacheManager:
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks)) self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks) main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_resource.set(1.0) main_process_metrics.available_gpu_resource.set(1.0)
def _enable_cpu_cache(self):
"""
_enable_cpu_cache function used to enable cpu cache.
"""
# ipc_cache_queue_port = self.cache_config.cache_queue_port
# self.cache_task_queue = CacheQueueManager(
# rank=0,
# mp_num=tensor_parallel_size,
# port=ipc_cache_queue_port,
# )
# Start a listener thread to receive data-transfer task results
self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
self.transfer_recv_thread.start()
def can_allocate_gpu_blocks(self, num_blocks: int): def can_allocate_gpu_blocks(self, num_blocks: int):
""" """
Check if num_blocks gpu blocks can be allocated. Check if num_blocks gpu blocks can be allocated.
@@ -1377,3 +1403,70 @@ class PrefixCacheManager:
except Exception as e: except Exception as e:
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}") logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
raise e raise e
def reset(self):
"""
Reset the RadixTree.
"""
if len(self.node_map) == 0:
return
logger.info("Resetting the RadixTree!")
# wait for swap tasks to finish
if self.gpu_free_task_future is not None:
self.gpu_free_task_future.result()
self.gpu_free_task_future = None
for event in list(self.task_swapping_event.values()):
event.wait()
self.task_swapping_event.clear()
# clear node map
self.node_map.clear()
self.req_leaf_map.clear()
self.leaf_req_map.clear()
self.unfilled_req_block_map.clear()
self.cache_info.clear()
# reset gpu cache data structure
self.gpu_lru_leaf_heap.clear()
self.gpu_lru_leaf_set.clear()
# reset cpu cache data structure
self.cpu_lru_leaf_heap.clear()
self.cpu_lru_leaf_set.clear()
# reset gpu/cpu free block list
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
if self.num_cpu_blocks > 0:
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
else:
self.cpu_free_block_list = []
heapq.heapify(self.gpu_free_block_list)
heapq.heapify(self.cpu_free_block_list)
# reset node/tree
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
# reset metrics
self.metrics.reset_metrics()
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
def clear_prefix_cache(self):
"""
If the model weights status is updating or clearing, reset prefix cache tree
"""
logger.info("Start a thread to clear prefix cache when model weights are cleared.")
prefix_tree_status_signal = self.prefix_tree_status_signal
while True:
if prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARING:
self.reset()
prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARED
logger.info("Prefix cache tree is cleared.")
if prefix_tree_status_signal.value[0] == PrefixTreeStatus.UPDATING:
prefix_tree_status_signal.value[0] = PrefixTreeStatus.NORMAL
logger.info("Prefix cache tree is updated.")
time.sleep(0.01)
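
A small sketch of the invariants reset() is expected to restore (mgr stands for a live PrefixCacheManager; the checks simply follow the assignments in the code above):

def check_reset_invariants(mgr):
    # After reset(), the manager should look freshly constructed.
    assert len(mgr.node_map) == 0 and len(mgr.req_leaf_map) == 0
    assert len(mgr.gpu_free_block_list) == mgr.num_gpu_blocks
    assert len(mgr.cpu_free_block_list) == mgr.num_cpu_blocks
    assert len(mgr.node_id_pool) == mgr.num_gpu_blocks + mgr.num_cpu_blocks
    assert mgr.available_gpu_resource == 1.0   # free blocks / total blocks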


@@ -602,13 +602,11 @@ class ParallelConfig:
) )
) )
dist.collective._set_custom_gid(None) dist.collective._set_custom_gid(None)
# same ep group id # same ep group id
if self.enable_expert_parallel: if self.enable_expert_parallel:
dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset) dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
self.ep_group = dist.new_group(range(self.expert_parallel_size)) self.ep_group = dist.new_group(range(self.expert_parallel_size))
dist.collective._set_custom_gid(None) dist.collective._set_custom_gid(None)
logger.info( logger.info(
f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}." f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
) )


@@ -403,6 +403,7 @@ class EngineArgs:
""" """
Post-initialization processing to set default tokenizer if not provided. Post-initialization processing to set default tokenizer if not provided.
""" """
if not self.tokenizer: if not self.tokenizer:
self.tokenizer = self.model self.tokenizer = self.model
if self.splitwise_role == "decode": if self.splitwise_role == "decode":
@@ -411,8 +412,8 @@ class EngineArgs:
self.enable_prefix_caching = False self.enable_prefix_caching = False
if not current_platform.is_cuda(): if not current_platform.is_cuda():
self.enable_prefix_caching = False self.enable_prefix_caching = False
if self.dynamic_load_weight: # if self.dynamic_load_weight:
self.enable_prefix_caching = False # self.enable_prefix_caching = False
if self.enable_logprob: if self.enable_logprob:
if self.speculative_config is not None: if self.speculative_config is not None:
raise NotImplementedError("Logprob does not support speculation_config.") raise NotImplementedError("Logprob does not support speculation_config.")


@@ -188,6 +188,24 @@ class EngineService:
create=True, create=True,
) )
cache_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
swap_space_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
self.swap_space_ready_signal = IPCSignal(
name="swap_space_ready_signal",
array=swap_space_ready_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
model_weights_status = np.zeros([1], dtype=np.int32) model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal( self.model_weights_status_signal = IPCSignal(
name="model_weights_status", name="model_weights_status",
@@ -197,6 +215,24 @@ class EngineService:
create=True, create=True,
) )
prefix_tree_status = np.zeros([1], dtype=np.int32)
self.prefix_tree_status_signal = IPCSignal(
name="prefix_tree_status",
array=prefix_tree_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
kv_cache_status = np.zeros([1], dtype=np.int32)
self.kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
def start_worker_queue_service(self, start_queue): def start_worker_queue_service(self, start_queue):
""" """
start queue service for engine worker communication start queue service for engine worker communication
@@ -935,7 +971,7 @@ class EngineService:
threading.Thread(target=receiver_loop, daemon=True).start() threading.Thread(target=receiver_loop, daemon=True).start()
def start_cache_service(self, device_ids, ipc_signal_suffix): def start_cache_service(self, device_ids, ipc_signal_suffix, create_cache_tensor):
return self.resource_manager.cache_manager.launch_cache_manager( return self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config, cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
@@ -945,6 +981,7 @@ class EngineService:
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
), ),
pid_suffix=ipc_signal_suffix, pid_suffix=ipc_signal_suffix,
create_cache_tensor=create_cache_tensor,
) )
def check_and_free_block_tables(self): def check_and_free_block_tables(self):
@@ -971,8 +1008,12 @@ class EngineService:
self.exist_task_signal.clear() self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear() self.exist_swapped_task_signal.clear()
self.worker_healthy_live_signal.clear() self.worker_healthy_live_signal.clear()
self.cache_ready_signal.clear()
self.swap_space_ready_signal.clear()
self.exist_prefill_task_signal.clear() self.exist_prefill_task_signal.clear()
self.model_weights_status_signal.clear() self.model_weights_status_signal.clear()
self.prefix_tree_status_signal.clear()
self.kv_cache_status_signal.clear()
if hasattr(self, "send_response_server") and self.send_response_server is not None: if hasattr(self, "send_response_server") and self.send_response_server is not None:
self.send_response_server.close() self.send_response_server.close()
if hasattr(self, "recv_request_server") and self.recv_request_server is not None: if hasattr(self, "recv_request_server") and self.recv_request_server is not None:

View File

@@ -126,14 +126,14 @@ class LLMEngine:
self.engine.start() self.engine.start()
if self.do_profile == 0 and ( # If the block number is specified and the model is deployed in mixed mode, start the cache manager first
self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed" if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
):
device_ids = self.cfg.device_ids.split(",") device_ids = self.cfg.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix) self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix, True)
# Start workers
self.worker_proc = self._start_worker_service() self.worker_proc = self._start_worker_service()
console_logger.info("Waiting worker processes ready...") console_logger.info("Waiting for worker processes to be ready...")
time.sleep(5) time.sleep(5)
self.worker_init_status = dict() self.worker_init_status = dict()
@@ -157,10 +157,16 @@ class LLMEngine:
return False return False
time.sleep(1) time.sleep(1)
# If block number is not specified, let workers do profiling to determine the block number,
# and then start the cache manager
if self.do_profile: if self.do_profile:
self._stop_profile() self._stop_profile()
elif self.cfg.cache_config.enable_prefix_caching:
device_ids = self.cfg.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix, False)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": # Launch components: scheduler, cache_manager, expert_service et.al.
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.launched_cache_manager_signal.value[0] = 1 self.launched_cache_manager_signal.value[0] = 1
if api_server_pid is not None: if api_server_pid is not None:
@@ -174,6 +180,24 @@ class LLMEngine:
return False return False
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
# Print blocks number & max running requests to console
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
block_size = self.cfg.cache_config.block_size
num_gpu_blocks = self.cfg.cache_config.num_gpu_blocks_override or self.cfg.cache_config.total_block_num
num_cpu_blocks = self.cfg.cache_config.num_cpu_blocks
max_running_requests = min(
(num_gpu_blocks + num_cpu_blocks) * block_size // self.cfg.max_model_len,
self.cfg.scheduler_config.max_num_seqs,
)
console_logger.info(
f"Detected {num_gpu_blocks} gpu blocks and {num_cpu_blocks} cpu blocks in cache (block size: {block_size})."
)
console_logger.info(
f"FastDeploy will be serving {max_running_requests} running requests "
f"if each sequence reaches its maximum length: {self.cfg.max_model_len}"
)
return True return True
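
A worked instance of the numbers reported above (illustrative values):

block_size = 64
num_gpu_blocks, num_cpu_blocks = 4096, 0
max_model_len = 8192
max_num_seqs = 128
max_running_requests = min((num_gpu_blocks + num_cpu_blocks) * block_size // max_model_len,
                           max_num_seqs)
# (4096 + 0) * 64 // 8192 = 32 -> FastDeploy reports 32 running requests at full sequence length.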
def _get_generated_result(self): def _get_generated_result(self):
@@ -622,7 +646,9 @@ class LLMEngine:
self.engine.resource_manager.reset_cache_config(self.cfg.cache_config) self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
device_ids = self.cfg.device_ids.split(",") device_ids = self.cfg.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix) self.cache_manager_processes = self.engine.start_cache_service(
device_ids, self.ipc_signal_suffix, self.cfg.scheduler_config.splitwise_role != "mixed"
)
def check_health(self, time_interval_threashold=30): def check_health(self, time_interval_threashold=30):
""" """


@@ -142,7 +142,6 @@ class ExpertService:
if hasattr(self, "cache_manager_processes"): if hasattr(self, "cache_manager_processes"):
self.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() self.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.engine.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes: for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}") llm_logger.info(f"Killing cache manager process {p.pid}")
try: try:


@@ -145,11 +145,13 @@ class ResourceManagerV1(ResourceManager):
if preempted_req.request_id in self.req_dict: if preempted_req.request_id in self.req_dict:
del self.req_dict[preempted_req.request_id] del self.req_dict[preempted_req.request_id]
self._free_blocks(preempted_req) self._free_blocks(preempted_req)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
main_process_metrics.num_requests_running.dec(1) main_process_metrics.num_requests_running.dec(1)
else: else:
self._free_blocks(preempted_req) self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0 preempted_req.cached_block_num = 0
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id) self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
main_process_metrics.num_requests_waiting.inc(1) main_process_metrics.num_requests_waiting.inc(1)
main_process_metrics.num_requests_running.dec(1) main_process_metrics.num_requests_running.dec(1)
preempted_reqs.append(preempted_req) preempted_reqs.append(preempted_req)


@@ -21,13 +21,20 @@ import traceback
import uuid import uuid
import numpy as np import numpy as np
from filelock import FileLock
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.config import ModelConfig from fastdeploy.config import ModelConfig
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient from fastdeploy.inter_communicator import (
IPCSignal,
KVCacheStatus,
ModelWeightsStatus,
PrefixTreeStatus,
ZmqIpcClient,
)
from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
@@ -60,6 +67,8 @@ class EngineClient:
enable_logprob=False, enable_logprob=False,
workers=1, workers=1,
tool_parser=None, tool_parser=None,
enable_prefix_caching=None,
splitwise_role=None,
): ):
architectures = ModelConfig({"model": model_name_or_path}).architectures[0] architectures = ModelConfig({"model": model_name_or_path}).architectures[0]
if MultimodalRegistry.contains_model(architectures): if MultimodalRegistry.contains_model(architectures):
@@ -79,6 +88,8 @@ class EngineClient:
self.reasoning_parser = reasoning_parser self.reasoning_parser = reasoning_parser
self.data_processor = input_processor.create_processor() self.data_processor = input_processor.create_processor()
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.enable_prefix_caching = enable_prefix_caching
self.enable_splitwise = splitwise_role != "mixed"
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if tensor_parallel_size <= max_chips_per_node: if tensor_parallel_size <= max_chips_per_node:
@@ -104,10 +115,27 @@ class EngineClient:
suffix=port, suffix=port,
create=False, create=False,
) )
prefix_tree_status = np.zeros([1], dtype=np.int32)
self.prefix_tree_status_signal = IPCSignal(
name="prefix_tree_status",
array=prefix_tree_status,
dtype=np.int32,
suffix=port,
create=False,
)
kv_cache_status = np.zeros([1], dtype=np.int32)
self.kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=port,
create=False,
)
self.connection_manager = DealerConnectionManager( self.connection_manager = DealerConnectionManager(
pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50)) pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
) )
self.connection_initialized = False self.connection_initialized = False
self.clear_update_lock = FileLock(f"/tmp/fd_weight_clear_update_lock__pid{pid}_port{port}.lock")
def create_zmq_client(self, model, mode): def create_zmq_client(self, model, mode):
""" """
@@ -298,7 +326,7 @@ class EngineClient:
Check the health of the model server by checking whether all workers are alive. Check the health of the model server by checking whether all workers are alive.
""" """
if self.model_weights_status_signal.value[0] == 0: if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
return True, "" return True, ""
else: else:
return False, "No model weight enabled" return False, "No model weight enabled"
@@ -309,18 +337,41 @@ class EngineClient:
1 : worker receive the signal and start to update model weight 1 : worker receive the signal and start to update model weight
2 : worker update finish and notify client 2 : worker update finish and notify client
""" """
if self.model_weights_status_signal.value[0] == 0: with self.clear_update_lock:
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
return True, "" return True, ""
if self.model_weights_status_signal.value[0] == 1: if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
return False, "updating model weight already" return False, "worker is updating model weight already"
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
return False, "worker is clearing model weight, cannot update now"
self.model_weights_status_signal.value[0] = 1 self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}") api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
while self.model_weights_status_signal.value[0] != 0 and timeout != 0: all_updated = False
while timeout >= 0 and not all_updated:
api_server_logger.info(
f"Updating model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_updated = weight_updated and cache_updated and prefix_updated
else:
all_updated = weight_updated and cache_updated
else:
all_updated = weight_updated
time.sleep(1) time.sleep(1)
timeout -= 1 timeout -= 1
continue if timeout < 0:
if self.model_weights_status_signal.value[0] != 0:
return False, "Update model weight timeout" return False, "Update model weight timeout"
time.sleep(1) time.sleep(1)
return True, "" return True, ""
@@ -331,20 +382,45 @@ class EngineClient:
-1 : worker receive the signal and start to clear model weight -1 : worker receive the signal and start to clear model weight
-2 : worker clear finish and notify client -2 : worker clear finish and notify client
""" """
if self.model_weights_status_signal.value[0] == -2:
return True, ""
if self.model_weights_status_signal.value[0] == -1:
return False, "clearing model weight already"
self.model_weights_status_signal.value[0] = -1 with self.clear_update_lock:
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
return False, "worker is clearing model weight already"
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
return False, "worker is updating model weight, cannot clear now"
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}") api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
while self.model_weights_status_signal.value[0] != -2 and timeout != 0: all_cleared = False
while timeout >= 0 and not all_cleared:
api_server_logger.info(
f"Clearing model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_cleared = weight_cleared and cache_cleared and prefix_cleared
else:
all_cleared = weight_cleared and cache_cleared
else:
all_cleared = weight_cleared
time.sleep(1) time.sleep(1)
timeout -= 1 timeout -= 1
continue
if self.model_weights_status_signal.value[0] != -2: if timeout < 0:
return False, "clear model weight timeout" return False, "Clear model weight timeout"
time.sleep(1) time.sleep(1)
return True, "" return True, ""


@@ -162,6 +162,8 @@ async def lifespan(app: FastAPI):
enable_logprob=args.enable_logprob, enable_logprob=args.enable_logprob,
workers=args.workers, workers=args.workers,
tool_parser=args.tool_call_parser, tool_parser=args.tool_call_parser,
enable_prefix_caching=args.enable_prefix_caching,
splitwise_role=args.splitwise_role,
) )
await engine_client.connection_manager.initialize() await engine_client.connection_manager.initialize()
app.state.dynamic_load_weight = args.dynamic_load_weight app.state.dynamic_load_weight = args.dynamic_load_weight

View File

@@ -116,6 +116,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Max pre-fetch requests number in PD # Max pre-fetch requests number in PD
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
"FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
# Whether to clear cpu cache when clearing model weights.
"FD_ENABLE_SWAP_SPACE_CLEARING": lambda: int(os.getenv("FD_ENABLE_SWAP_SPACE_CLEARING", "0")),
} }
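
A quick sketch of how the new flag is meant to be toggled (it is read through fastdeploy.envs, as in the cache transfer manager above; setting it before importing fastdeploy is the safe pattern):

import os
os.environ["FD_ENABLE_SWAP_SPACE_CLEARING"] = "1"   # opt in to freeing pinned cpu swap space on clear

from fastdeploy import envs
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
    print("cpu swap space will be freed on clear and re-allocated on the next weight update")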


@@ -17,15 +17,25 @@
from .engine_cache_queue import EngineCacheQueue from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal, shared_memory_exists from .ipc_signal import IPCSignal, shared_memory_exists
from .ipc_signal_const import (
ExistTaskStatus,
KVCacheStatus,
ModelWeightsStatus,
PrefixTreeStatus,
)
from .zmq_client import ZmqIpcClient from .zmq_client import ZmqIpcClient
from .zmq_server import ZmqIpcServer, ZmqTcpServer from .zmq_server import ZmqIpcServer, ZmqTcpServer
__all__ = [ __all__ = [
"ZmqIpcClient", "ZmqIpcClient",
"ZmqIpcServer",
"ZmqTcpServer",
"IPCSignal", "IPCSignal",
"EngineWorkerQueue", "EngineWorkerQueue",
"EngineCacheQueue", "EngineCacheQueue",
"ZmqTcpServer",
"ZmqIpcServer",
"shared_memory_exists", "shared_memory_exists",
"ExistTaskStatus",
"PrefixTreeStatus",
"ModelWeightsStatus",
"KVCacheStatus",
] ]


@@ -0,0 +1,32 @@
from dataclasses import dataclass
@dataclass
class ModelWeightsStatus:
NORMAL = 0
UPDATING = 1
CLEARING = -1
CLEARED = -2
@dataclass
class PrefixTreeStatus:
NORMAL = 0
UPDATING = 1
CLEARING = -1
CLEARED = -2
@dataclass
class KVCacheStatus:
NORMAL = 0
UPDATING = 1
CLEARING = -1
CLEARED = -2
@dataclass
class ExistTaskStatus:
EMPTY = 0
EXIST = 1
REFUSE = 2
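
These constants are plain integer status words written into shared memory through IPCSignal. A minimal sketch of a client attaching to an existing signal and requesting a clear (the suffix must match whatever the engine used when creating the signal):

import numpy as np
from fastdeploy.inter_communicator import IPCSignal, ModelWeightsStatus

suffix = 8021   # e.g. the engine worker queue port, as used elsewhere in this PR
status = IPCSignal(name="model_weights_status", array=np.zeros([1], dtype=np.int32),
                   dtype=np.int32, suffix=suffix, create=False)   # attach, do not create
if status.value[0] == ModelWeightsStatus.NORMAL:
    status.value[0] = ModelWeightsStatus.CLEARING                 # ask workers to clear weights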


@@ -298,7 +298,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
) )
if layer.reduce_results and layer.tp_size > 1: if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out) tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
return fused_moe_out return fused_moe_out
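
The only change here is passing the layer's tensor-parallel group explicitly instead of relying on a process-wide default. Assuming the helper wraps paddle.distributed.all_reduce, the call is roughly:

import paddle.distributed as dist

def tp_all_reduce_sketch(tensor, tp_group):
    # Sum-reduce across tensor-parallel ranks only; ep_group and other groups are untouched.
    dist.all_reduce(tensor, group=tp_group)
    return tensor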


@@ -25,6 +25,7 @@ from paddle import nn
from paddleformers.utils.log import logger from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig from fastdeploy.config import FDConfig
from fastdeploy.inter_communicator import ModelWeightsStatus
class DynamicWeightManager: class DynamicWeightManager:
@@ -143,12 +144,11 @@ class DynamicWeightManager:
if self.parallel_config.tensor_parallel_size > 1: if self.parallel_config.tensor_parallel_size > 1:
# tp barrier # tp barrier
paddle.distributed.barrier(self.parallel_config.tp_group) paddle.distributed.barrier(self.parallel_config.tp_group)
# shutdown tp group
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group) paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
# step3: update model weight signal paddle.distributed.barrier(self.parallel_config.ep_group)
# step4: release kv cache in the runner paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
self._update_shared_status(pid, -2) self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str): def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
"""Update model parameters from given state dictionary.""" """Update model parameters from given state dictionary."""
@@ -184,8 +184,7 @@ class DynamicWeightManager:
paddle.distributed.barrier(self.parallel_config.ep_group) paddle.distributed.barrier(self.parallel_config.ep_group)
if not self.first_load: if not self.first_load:
self._update_shared_status(pid, 0) self._update_shared_status(pid, ModelWeightsStatus.NORMAL)
self.first_load = False self.first_load = False
def _get_gpu_id(self) -> int: def _get_gpu_id(self) -> int:
@@ -252,25 +251,19 @@ class DynamicWeightManager:
""" """
check model weights status check model weights status
""" """
is_stop = 0 logger.info(f"dynamic weight manager is checking model weights status! {model_weights_status.value[0]}")
while model_weights_status.value[0] != 0: while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
if model_weights_status.value[0] == 1: if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
logger.info("infer engine stopped! start to load new checkpoint...") logger.info("infer engine stopped! start to load new checkpoint...")
model_runner.update_parameters(pid) model_runner.update_parameters(pid)
elif model_weights_status.value[0] == -1: while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
time.sleep(0.01)
logger.info("finished loading new checkpoint")
elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
logger.info("infer engine stopped! start to clear checkpoint...") logger.info("infer engine stopped! start to clear checkpoint...")
model_runner.clear_requests() model_runner.clear_requests()
model_runner.clear_parameters(pid) model_runner.clear_parameters(pid)
while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
while True: time.sleep(0.01)
if model_weights_status.value[0] == 0:
logger.info("finished loading new checkpoint")
break
elif is_stop == 1 or (model_weights_status.value[0] == -2 and is_stop == 0):
if is_stop == 0:
logger.info("finished clearing checkpoint") logger.info("finished clearing checkpoint")
is_stop = 1 time.sleep(0.01)
time.sleep(0.001)
break
else:
time.sleep(0.001)
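
Stripped of logging and the fused old code, the new worker-side loop is a small state machine over ModelWeightsStatus; a sketch of the control flow:

import time
from fastdeploy.inter_communicator import ModelWeightsStatus

def watch_weights_status(model_weights_status, model_runner, pid):
    while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
        if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
            model_runner.update_parameters(pid)          # reload the new checkpoint
            while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                time.sleep(0.01)                         # wait for the NORMAL ack
        elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
            model_runner.clear_requests()                # drop in-flight requests
            model_runner.clear_parameters(pid)           # release weights and kv cache
            while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
                time.sleep(0.01)                         # wait for the CLEARED ack
        time.sleep(0.01)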


@@ -59,6 +59,7 @@ else:
set_value_by_flags_and_idx, set_value_by_flags_and_idx,
share_external_data, share_external_data,
speculate_schedule_cache, speculate_schedule_cache,
set_data_ipc,
) )
from fastdeploy.model_executor.pre_and_post_process import ( from fastdeploy.model_executor.pre_and_post_process import (
@@ -75,7 +76,7 @@ import zmq
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import ZmqIpcClient from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -1146,7 +1147,7 @@ class GPUModelRunner(ModelRunnerBase):
        """
        Initialize kv cache
        """
-        cache_kvs = {}
+        # cache_kvs = {}
        max_block_num = self.num_gpu_blocks

        # Get kv cache dtype
@@ -1169,47 +1170,59 @@ class GPUModelRunner(ModelRunnerBase):
            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]

        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (
-            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
-        ):
-            cache_kvs_list = []
-            for i in range(self.model_config.num_hidden_layers):
-                key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
-                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
-                cache_kvs_list.append(key_cache)
-                value_cache = paddle.empty(shape=[], dtype=cache_type)
-                value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
-                cache_kvs_list.append(value_cache)
-            self.share_inputs["caches"] = cache_kvs_list
-        else:
-            for i in range(self.model_config.num_hidden_layers):
-                cache_kvs[f"key_caches_{i}"] = paddle.full(
-                    shape=kv_cache_shape,
-                    fill_value=0,
-                    dtype=cache_type,
-                )
-                cache_kvs[f"value_caches_{i}"] = paddle.full(
-                    shape=kv_cache_shape,
-                    fill_value=0,
-                    dtype=cache_type,
-                )
-                if kv_cache_quant_type == "block_wise_fp8":
-                    cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
-                        shape=kv_cache_scale_shape,
-                        fill_value=0,
-                        dtype=paddle.get_default_dtype(),
-                    )
-                    cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
-                        shape=kv_cache_scale_shape,
-                        fill_value=0,
-                        dtype=paddle.get_default_dtype(),
-                    )
-            self.share_inputs["caches"] = list(cache_kvs.values())
-            for value in cache_kvs.values():
-                del value
+        cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
+        cache_ready_signal = IPCSignal(
+            name="cache_ready_signal",
+            array=cache_ready_signal_data,
+            dtype=np.int32,
+            suffix=self.parallel_config.engine_worker_queue_port,
+            create=False,
+        )
+
+        # Check if gpu runner needs to create kv cache
+        # 1. During profiling, it creates its own kv cache.
+        # 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled.
+        create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed"
+
+        if not create_cache_tensor:
+            logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
+            while cache_ready_signal.value[self.local_rank] != 1:
+                time.sleep(1)
+            logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
+
+        logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
+        cache_kvs_list = []
+        for i in range(self.model_config.num_hidden_layers):
+            key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+            if create_cache_tensor:
+                logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
+                key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                set_data_ipc(key_cache, key_cache_name)
+                set_data_ipc(val_cache, val_cache_name)
+                cache_kvs_list.extend([key_cache, val_cache])
+                if kv_cache_quant_type == "block_wise_fp8":
+                    key_cache_scales = paddle.full(
+                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
+                    )
+                    val_cache_scales = paddle.full(
+                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
+                    )
+                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+            else:
+                logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
+                key_cache = paddle.empty(shape=[], dtype=cache_type)
+                val_cache = paddle.empty(shape=[], dtype=cache_type)
+                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
+                val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
+                cache_kvs_list.extend([key_cache, val_cache])
+
+        self.share_inputs["caches"] = cache_kvs_list
+        if not profile and create_cache_tensor:
+            cache_ready_signal.value[self.local_rank] = 1
+            logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")

        paddle.device.cuda.empty_cache()

    def initialize_attn_backend(self) -> None:
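The comment block in this hunk ("Check if gpu runner needs to create kv cache") carries the core idea: in mixed deployments (or while profiling) the runner allocates the KV tensors itself and publishes them over CUDA IPC with `set_data_ipc`, while in disaggregated deployments it waits on `cache_ready_signal` and attaches to tensors the cache manager already exported, via `share_external_data`. A condensed sketch of that decision follows; the wrapper function and the injected `create_tensor`/`attach_tensor` callables are illustrative assumptions, not FastDeploy APIs.

```python
# Hedged sketch of the create-vs-attach decision, assuming injected helpers.
import time


def init_kv_cache(layer_ids, kv_cache_shape, cache_ready, local_rank,
                  profile, splitwise_role, create_tensor, attach_tensor):
    """Either create KV tensors (and publish them over IPC) or attach to existing ones."""
    # Mixed deployment or profiling: this process owns the KV cache.
    create_cache_tensor = profile or splitwise_role == "mixed"

    if not create_cache_tensor:
        # Disaggregated deployment: the cache manager allocates; wait for its signal.
        while cache_ready[local_rank] != 1:
            time.sleep(1)

    caches = []
    for i in layer_ids:
        key_name, val_name = f"key_caches_{i}", f"value_caches_{i}"
        if create_cache_tensor:
            # create_tensor: allocate + export over IPC (set_data_ipc in the real code)
            caches += [create_tensor(key_name, kv_cache_shape), create_tensor(val_name, kv_cache_shape)]
        else:
            # attach_tensor: map the already-exported tensor (share_external_data in the real code)
            caches += [attach_tensor(key_name, kv_cache_shape), attach_tensor(val_name, kv_cache_shape)]

    if not profile and create_cache_tensor:
        cache_ready[local_rank] = 1  # tell cache managers the tensors are published
    return caches
```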
@@ -1935,6 +1948,7 @@ class GPUModelRunner(ModelRunnerBase):
        self.share_inputs.pop("caches", None)
        if self.forward_meta is not None:
            self.forward_meta.clear_caches()
+        paddle.device.cuda.empty_cache()

    def clear_parameters(self, pid):
        """Dynamic model loader use to clear parameters use for RL"""

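The one added line in this hunk is easy to miss but is what makes /clear_load_weight actually return memory: popping the cache references only hands the blocks back to Paddle's caching allocator, and the explicit `paddle.device.cuda.empty_cache()` releases the now-idle blocks to the driver. A minimal sketch of the clear path, with simplified arguments rather than the real `GPUModelRunner` state:

```python
# Hedged sketch of a cache-clearing helper; method names follow the diff,
# but this is not the full GPUModelRunner implementation.
import paddle


def clear_kv_cache(share_inputs: dict, forward_meta=None) -> None:
    """Drop KV cache references and return the freed blocks to the GPU driver."""
    share_inputs.pop("caches", None)      # drop the runner's references
    if forward_meta is not None:
        forward_meta.clear_caches()       # drop references held by forward metadata
    paddle.device.cuda.empty_cache()      # release the now-unreferenced blocks
```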
View File

@@ -42,7 +42,7 @@ from fastdeploy.config import (
)
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import ExistTaskStatus, IPCSignal, ModelWeightsStatus
from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
@@ -183,7 +183,7 @@ class PaddleDisWorkerProc:
            name="launched_expert_service_signal",
            array=launched_expert_service_signal_data,
            dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )
        while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
@@ -200,7 +200,7 @@ class PaddleDisWorkerProc:
            name="worker_ready_signal",
            array=workers_ready,
            dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )
        self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1
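Several hunks in this file switch the IPC suffix from `engine_pid` to `engine_worker_queue_port`. The point of the change (per the "fix ipc suffix, use port instead" commit) is that the engine and every worker can derive the same shared-memory name from configuration alone, without knowing the engine's pid. A hedged sketch of how the two sides might line up; the exact naming scheme inside `IPCSignal` is an assumption, and the port value and array size are hypothetical:

```python
# Hedged sketch of opening the same IPCSignal from both sides; behavior described
# in the comments is assumed, not taken from fastdeploy.inter_communicator's source.
import numpy as np

from fastdeploy.inter_communicator import IPCSignal

ENGINE_WORKER_QUEUE_PORT = 8002  # hypothetical port; both processes must agree on it

# Engine side: create the backing shared array, keyed by name + port suffix.
worker_ready = IPCSignal(
    name="worker_ready_signal",
    array=np.zeros([8], dtype=np.int32),  # size illustrative only
    dtype=np.int32,
    suffix=ENGINE_WORKER_QUEUE_PORT,
    create=True,
)

# Worker side: attach to the same region by recomputing the identical name.
worker_ready_view = IPCSignal(
    name="worker_ready_signal",
    array=np.zeros([8], dtype=np.int32),
    dtype=np.int32,
    suffix=ENGINE_WORKER_QUEUE_PORT,
    create=False,
)
worker_ready_view.value[0] = 1  # becomes visible to the engine immediately
```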
@@ -279,8 +279,8 @@ class PaddleDisWorkerProc:
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
        self.model_weights_signal = np.zeros([1], dtype=np.int32)
        while True:
-            if local_rank == 0:
-                if self.model_weights_status.value[0] != 0:
+            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
+                if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                    self.model_weights_signal[0] = int(self.model_weights_status.value[0])
            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
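`_broadcast_model_weights_signal` is referenced here but not shown in the hunk. A plausible shape for it, purely as an assumption, is a one-element broadcast from the rank that reads the shared status, so every TP/EP rank acts on the same update-or-clear decision:

```python
# Hedged sketch: one way to propagate the rank-0 signal to the whole group.
# This is an assumption about what _broadcast_model_weights_signal does, not its source.
import numpy as np
import paddle
import paddle.distributed as dist


def broadcast_model_weights_signal(signal: np.ndarray, src: int = 0, group=None) -> int:
    """Broadcast a one-element int32 signal from `src`; return the agreed value."""
    tensor = paddle.to_tensor(signal, dtype="int32")
    dist.broadcast(tensor, src=src, group=group)
    return int(tensor.numpy()[0])
```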
@@ -306,7 +306,7 @@
                if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
                    self.task_queue.read_finish_flag.set(1)
                else:
-                    self.exist_task_signal.value[0] = 1
+                    self.exist_task_signal.value[0] = ExistTaskStatus.EXIST

            if self.parallel_config.tensor_parallel_size > 1:
                # Synchronize the signal for other workers
@@ -317,7 +317,7 @@
                    paddle.distributed.barrier(self.parallel_config.ep_group)
                else:
                    paddle.distributed.barrier(self.parallel_config.tp_group)

-            if self.model_weights_signal[0] != 0:
+            if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
                logger.info(
                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
                )
@@ -332,17 +332,17 @@
                    self.worker.model_runner,
                    self.parallel_config.engine_worker_queue_port,
                )
-                self.model_weights_signal[0] = 0
+                self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")

-            if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
+            if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                logger.info(f"Rank: {self.local_rank} Detected new requests.")
                self.insert_step = True

                tasks, read_finish = self.task_queue.get_tasks()
                if read_finish:
                    # Ensure that every worker get the task
-                    self.exist_task_signal.value[0] = 0
+                    self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
                    self.task_queue.read_finish_flag.set(0)

                req_dicts = []
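The remaining edits in this hunk replace bare integers with named constants from `fastdeploy.inter_communicator`. The sketch below shows the task hand-off they guard; the `ExistTaskStatus` values are assumptions inferred from the old `0`/`1` literals, and `task_queue` stands in for the engine worker queue object used above:

```python
# Hedged sketch of the task hand-off; ExistTaskStatus values are assumed.
class ExistTaskStatus:
    EMPTY = 0  # no pending tasks for this round
    EXIST = 1  # the engine scheduled new tasks to fetch


def poll_tasks(exist_task_signal, task_queue):
    """Fetch tasks when the engine flags them, then reset the flag for the next round."""
    if exist_task_signal.value[0] == ExistTaskStatus.EXIST or task_queue.read_finish_flag.get() == 1:
        tasks, read_finish = task_queue.get_tasks()
        if read_finish:
            # Ensure that every worker has seen the batch before clearing the flag.
            exist_task_signal.value[0] = ExistTaskStatus.EMPTY
            task_queue.read_finish_flag.set(0)
        return tasks
    return []
```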
@@ -418,25 +418,14 @@
                name="get_profile_block_num",
                array=get_profile_block_num,
                dtype=np.int32,
-                suffix=self.parallel_config.engine_pid,
+                suffix=self.parallel_config.engine_worker_queue_port,
                create=False,
            )
            self.get_profile_block_num_signal.value[0] = num_blocks_local
        else:
            num_blocks_local = self.fd_config.parallel_config.total_block_num
        logger.info(f"------- num_blocks_global: {num_blocks_local} --------")

-        # wait engine launch cache_manager
-        if self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed":
-            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
-            self.launched_cache_manager_signal = IPCSignal(
-                name="launched_cache_manager_signal",
-                array=launched_cache_manager_signal_data,
-                dtype=np.int32,
-                suffix=self.parallel_config.engine_pid,
-                create=False,
-            )
-            while np.any(self.launched_cache_manager_signal.value[0] <= 0):
-                time.sleep(0.01)
        # 4. init kv_cache with accurate num_blocks
        self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
@@ -488,7 +477,7 @@
            name="loaded_model_signal",
            array=loaded_model_signal_data,
            dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )
        if self.ranks > 1: