diff --git a/custom_ops/gpu_ops/update_inputs_v1.cu b/custom_ops/gpu_ops/update_inputs_v1.cu index 9229fdcf0..33076b073 100644 --- a/custom_ops/gpu_ops/update_inputs_v1.cu +++ b/custom_ops/gpu_ops/update_inputs_v1.cu @@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop, const int max_bsz, const int input_ids_stride, const int block_num_per_seq, - const int block_size) { + const int block_size, + bool prefill_one_step_stop) { int thread_idx = threadIdx.x; typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -54,23 +55,32 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop, seq_lens_encoder[thread_idx] = 0; } else { if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) { - // decoding - seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx]; - seq_lens_this_time[thread_idx] = 1; - seq_lens_encoder[thread_idx] = 0; - int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride; - input_ids_now[0] = next_tokens[thread_idx]; + if (prefill_one_step_stop) { + // prefill done, stop + stop_flags[thread_idx] = true; + seq_lens_this_time[thread_idx] = 0; + seq_lens_decoder[thread_idx] = 0; + seq_lens_encoder[thread_idx] = 0; + stop_flag_now_int = 1; + } else{ + // decoding + seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx]; + seq_lens_this_time[thread_idx] = 1; + seq_lens_encoder[thread_idx] = 0; + int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride; + input_ids_now[0] = next_tokens[thread_idx]; - // to judge whether block is not enough - int *block_table_now = block_tables + thread_idx * block_num_per_seq; - if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) { - // should be scheduled by server - is_block_step[thread_idx] = true; - seq_lens_this_time[thread_idx]= 0; - stop_flags[thread_idx] = true; - step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx]; - seq_lens_decoder[thread_idx] = 0; - stop_flag_now_int = 1; + // to judge whether block is not enough + int *block_table_now = block_tables + thread_idx * block_num_per_seq; + if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) { + // should be scheduled by server + is_block_step[thread_idx] = true; + seq_lens_this_time[thread_idx]= 0; + stop_flags[thread_idx] = true; + step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx]; + seq_lens_decoder[thread_idx] = 0; + stop_flag_now_int = 1; + } } } else { @@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags, #else auto cu_stream = input_ids.stream(); #endif + bool prefill_one_step_stop = false; + if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) { + if (env_p[0] == '1') { + prefill_one_step_stop = true; + } + } const int max_bsz = stop_flags.shape()[0]; const int now_bsz = seq_lens_this_time.shape()[0]; const int input_ids_stride = input_ids.shape()[1]; @@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags, max_bsz, input_ids_stride, block_num_per_seq, - block_size); + block_size, + prefill_one_step_stop); auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false); bool *not_need_stop_data = const_cast(not_need_stop.data()); diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 65d412f39..4502860d0 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ 
b/fastdeploy/cache_manager/cache_messager.py
@@ -14,7 +14,10 @@
 # limitations under the License.
 """
 
+import argparse
+import json
 import math
+import queue
 import threading
 import time
 import traceback
@@ -23,16 +26,72 @@ import numpy as np
 import paddle
 
 from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
+from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import (
     EngineWorkerQueue,
     IPCSignal,
     shared_memory_exists,
 )
-from fastdeploy.utils import get_logger
+from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
+from fastdeploy.utils import envs, get_logger
 
 logger = get_logger("cache_messager", "cache_messager.log")
 
 
+def parse_args():
+    """
+    Parse arguments from the command line.
+    """
+    parser = argparse.ArgumentParser("Cache Messager")
+    parser.add_argument(
+        "--splitwise_role",
+        type=str,
+        default="mixed",
+        help="splitwise role, can be decode, prefill or mixed",
+    )
+    parser.add_argument("--rank", type=int, default=0, help="current rank")
+    parser.add_argument("--device_id", type=int, default=0, help="device id")
+    parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
+    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
+    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
+    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
+    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
+    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument(
+        "--protocol",
+        type=str,
+        default="ipc",
+        help="cache transfer protocol, only support ipc now",
+    )
+    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
+    parser.add_argument("--cache_queue_port", type=int, default=9924, help="cache queue port")
+    parser.add_argument(
+        "--engine_worker_queue_port",
+        type=int,
+        default=9923,
+        help="engine worker queue port",
+    )
+    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
+    parser.add_argument("--block_size", type=int, default=64, help="cache block size (tokens)")
+    parser.add_argument(
+        "--cache_dtype",
+        type=str,
+        default="bfloat16",
+        choices=["uint8", "bfloat16"],
+        help="cache dtype",
+    )
+    parser.add_argument(
+        "--speculative_config",
+        type=json.loads,
+        default="{}",
+        help="speculative config",
+    )
+    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+
+    args = parser.parse_args()
+    return args
+
+
 class CacheMessager:
     """
     CacheMessager is used to send the cache data between the engine worker and the cache server.
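The CLI defined by parse_args above mirrors the launch command that PrefixCacheManager.launch_cache_messager assembles later in this diff, where cache_messager.py now runs as one standalone process per tensor-parallel rank. A minimal sketch of an equivalent manual launch is shown below; the device id, ports, layer/head counts, block count, and engine pid are illustrative placeholders, not values taken from this patch.

import subprocess
import sys

# Sketch only: in the real flow PrefixCacheManager.launch_cache_messager builds
# this command per rank, and the engine side creates the "cache_ready_signal"
# IPCSignal before the process starts, so a bare manual launch will not get far.
cmd = (
    f"{sys.executable} fastdeploy/cache_manager/cache_messager.py"
    " --splitwise_role prefill --rank 0 --device_id 0 --mp_num 1"
    " --num_layers 32 --head_dim 128 --kv_num_head 8"
    " --protocol ipc --pod_ip 0.0.0.0"
    " --engine_worker_queue_port 9923 --cache_queue_port 9924"
    " --num_gpu_blocks 2048 --block_size 64 --cache_dtype bfloat16"
    " --engine_pid 12345 --local_data_parallel_id 0"
    " --speculative_config '{}'"
)
subprocess.Popen(cmd, shell=True)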
@@ -69,11 +128,6 @@ class CacheMessager:
         Returns:
             None
         """
-
-        assert splitwise_role in [
-            "prefill",
-            "decode",
-        ], "splitwise_role must be prefill or decode"
         self.splitwise_role = splitwise_role
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
@@ -147,15 +201,16 @@ class CacheMessager:
         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
+        self.rank_id = self.rank + local_data_parallel_id * self.nranks
 
-        layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
-        layerwise_send_cache_thread.daemon = True
-        layerwise_send_cache_thread.start()
+        if self.splitwise_role != "mixed":
+            connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
+            connect_rdma_thread.daemon = True
+            connect_rdma_thread.start()
 
         logger.info(f"cache messager init finished, use {transfer_protocol}")
 
-    def _prefill_layerwise_send_cache_thread(self):
+    def prefill_layerwise_send_cache_thread(self):
         """
         layerwise_send_cache_thread:
             send cache to other instance
         """
         try:
             prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
             prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
-            prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
-            prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
+            prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.rank_id}.{self.gpu_id}"
+            prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
             step_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                 array=prefilled_step_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
                 create=not shared_memory_exists(prefilled_step_name),
             )
             layer_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                 array=prefilled_layer_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
                 create=not shared_memory_exists(prefilled_layer_name),
             )
-            logger.info(f"splitwise_complete_prefilled_step_{self.dp_rank_id}, gpu_id: {self.gpu_id}")
+            logger.info(f"splitwise_complete_prefilled_step_{self.rank_id}, gpu_id: {self.gpu_id}")
             step_shm_value.value[0] = -1
             layer_shm_value.value[0] = -1
@@ -187,6 +242,9 @@ class CacheMessager:
             self.last_step_idx = -1
             self.last_layer_idx = -1  # int32
 
+            max_step_idx = 100003
+            engine_recycled_count = 0
+
             while True:
                 cache_info = self.engine_worker_queue.get_cache_info()
@@ -202,11 +260,9 @@ class CacheMessager:
                             -len(current_info["dest_block_ids"]) :
                         ]
                        current_info["src_block_ids"] = current_src_blocks
-                        current_info["current_layer_ids"] = 0
                        current_info["status"] = "init"
                        logger.info(f"start cache_infos: {current_info}")
                        self.cache_info[info["request_id"]] = current_info
-                        self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
                    else:
                        self.cache_info[info["request_id"]] = info
                prefilled_layer_idx = layer_shm_value.value[0]
@@ -223,7 +279,18 @@ class CacheMessager:
                if not self.cache_info:
                    time.sleep(0.001)
                    continue
-                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+
+                if self.last_step_idx > prefilled_step_idx:
+                    engine_recycled_count += 1
+                self.last_step_idx = prefilled_step_idx  # only copy value read from shm memory
+                prefilled_step_idx = (
+                    prefilled_step_idx + max_step_idx * engine_recycled_count
+                )  # remap
prefilled_step_idx for comparison + + logger.debug( + f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx}," + f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}" + ) for req_id, item in list(self.cache_info.items()): if "status" not in item: continue @@ -294,12 +361,493 @@ class CacheMessager: logger.info(f"finish write cache {item['request_id']}") self.engine_worker_queue.finish_request_barrier.wait() if self.rank == 0: + # to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok. self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")]) logger.info(f"put write cache {item['request_id']}") del self.cache_info[req_id] - - self.last_step_idx = prefilled_step_idx - self.last_layer_idx = prefilled_layer_idx + self.last_layer_idx = prefilled_layer_idx except Exception as e: logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}") + + def _handle_connect_task(self): + while True: + try: + task = self.engine_worker_queue.get_connect_rdma_task() + if task is None: + time.sleep(0.001) + continue + logger.info(f"_handle_connect_task recv task: {task}") + task_id = task["task_id"] + ip, rdma_port = task["ip"], task["rdma_port"] + status = self.messager["rdma"].connect(ip, rdma_port) + if not status: + response = {"task_id": task_id, "success": False} + else: + response = {"task_id": task_id, "success": True} + self.engine_worker_queue.put_connect_rdma_task_response(response) + except Exception as e: + logger.error(f"handle_connect_task has exception: {e}") + + +class CacheMessagerV1: + """ + CacheMessager is used to send the cache data between the engine worker and the cache server. + """ + + def __init__( + self, + splitwise_role, + transfer_protocol, + pod_ip, + engine_worker_queue_port, + local_data_parallel_id, + gpu_cache_kvs, + rank, + nranks, + num_layers, + gpu_id=0, + block_size=64, + rdma_port=None, + ): + """ + Initialize the CacheMessager object. + + Args: + splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'. + transfer_protocol (str): support ipc and rdma + engine_worker_queue_port (int): engine_worker_queue port + gpu_cache_kvs (dict): GPU kv cache + rank (int): current rank + nranks (int): global rank number + num_layers (int): model layer number + gpu_id (int, optional): GPU ID + rdma_port (int, optional): RDMA port + + Returns: + None + """ + self.splitwise_role = splitwise_role + self.gpu_cache_kvs = gpu_cache_kvs + self.rank = rank + self.nranks = nranks + address = (pod_ip, engine_worker_queue_port) + self.engine_worker_queue = EngineWorkerQueue( + address=address, + is_server=False, + num_client=self.nranks, + client_id=self.rank, + local_data_parallel_id=local_data_parallel_id, + ) + self.block_size = block_size + transfer_protocol = transfer_protocol.split(",") + + logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}") + + # 1. 
initialize the cache_k_ptr_list and cache_v_ptr_list + self.num_layers = num_layers + cache_k_ptr_list = [] + cache_v_ptr_list = [] + cache_k = [] + cache_v = [] + self.messager = {} + for layer_idx in range(self.num_layers): + key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] + val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] + cache_k.append(key_cache) + cache_v.append(val_cache) + cache_k_ptr_list.append(key_cache.data_ptr()) + cache_v_ptr_list.append(val_cache.data_ptr()) + cache_k_ptr_list = np.array(cache_k_ptr_list) + cache_v_ptr_list = np.array(cache_v_ptr_list) + + # 2. initialize the block_bytes + cache_shape = key_cache.shape + max_block_num = cache_shape[0] + block_bytes = math.prod(cache_shape[1:]) + if key_cache.dtype == paddle.bfloat16: + block_bytes *= 2 + logger.info( + f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, " + f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}" + ) + self.block_bytes = block_bytes + + # 3. initialize the messager + for protocol in transfer_protocol: + if protocol == "ipc": + self.messager[protocol] = IPCCommManager( + self.rank, + gpu_id, + cache_k, + cache_v, + ) + local_device_id = int(str(cache_k[0].place)[-2]) + logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ") + + elif protocol == "rdma": + logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}") + + self.messager[protocol] = RDMACommManager( + splitwise_role, + rank, + gpu_id, + cache_k_ptr_list, + cache_v_ptr_list, + max_block_num, + block_bytes, + rdma_port, + ) + + self.gpu_id = gpu_id + self.cache_info = dict() + self.rank_id = self.rank + local_data_parallel_id * self.nranks + self.engine_cache_task_thread_lock = threading.Lock() + self.engine_cache_tasks = [dict() for _ in range(512)] + self.idx_cache_task_dict = {} + self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step + if splitwise_role == "prefill": + consume_signals_thread = threading.Thread(target=self.consume_signals) + consume_signals_thread.daemon = True + consume_signals_thread.start() + add_cache_task_thread = threading.Thread(target=self._add_cache_task_thread) + add_cache_task_thread.daemon = True + add_cache_task_thread.start() + + if self.splitwise_role != "mixed": + connect_rdma_thread = threading.Thread(target=self._handle_connect_task) + connect_rdma_thread.daemon = True + connect_rdma_thread.start() + + logger.info(f"cache messager init finished, use {transfer_protocol}") + + def _add_cache_task_thread(self): + while True: + try: + cache_info = self.engine_worker_queue.get_cache_info() + self.engine_worker_queue.finish_add_cache_task_barrier.wait() + finished_add_cache_task_req_ids = [] + if cache_info: + for info in cache_info: + if info["request_id"] in self.cache_info: + self.cache_info[info["request_id"]].update(info) + current_info = self.cache_info[info["request_id"]] + assert "dest_block_ids" in current_info and "src_block_ids" in current_info + finished_add_cache_task_req_ids.append(info["request_id"]) + decode_cached_block_num = len(current_info["src_block_ids"]) - len( + current_info["dest_block_ids"] + ) + padding_decode_block_ids = [-1 for i in range(decode_cached_block_num)] + current_info[ + "dest_block_ids" + ] + current_info["dest_block_ids"] = padding_decode_block_ids + current_info["decode_cached_tokens"] = decode_cached_block_num * self.block_size + 
current_info["sended_layer_id"] = -1 + current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size + current_info["status"] = "init" + logger.info(f"finish add cache task: {current_info}") + self.cache_info[info["request_id"]] = current_info + self.idx_cache_task_dict[current_info["current_id"]] = current_info + else: + self.cache_info[info["request_id"]] = info + if self.rank == 0 and finished_add_cache_task_req_ids: + self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids) + else: + time.sleep(0.001) + except Exception as e: + logger.info(f"add cache task occured error: {e}, {traceback.format_exc()!s}.") + + def prefill_layerwise_send_cache_thread(self): + """ + layerwise_send_cache_thread: + send cache to other instance + """ + while True: + try: + engine_indexes = self.cache_prefilled_engine_ids_queue.get() + self.engine_worker_queue.finish_request_barrier.wait() + block_start_end_list = [] + current_prefilled_token_num_list = [] + for engine_index in engine_indexes: + assert engine_index in self.idx_cache_task_dict + block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"] + prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"] + if ( + prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"] + ): # all chunks have been prefilled + block_id_end = len(self.idx_cache_task_dict[engine_index]["src_block_ids"]) + else: + block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end) + block_start_end_list.append((block_id_start, block_id_end)) + current_prefilled_token_num_list.append(prefilled_token_num) + while True: # from layer0 to last layer + sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"] + start_layer_idx = sended_layer_idx + 1 + with self.engine_cache_task_thread_lock: # to check end_layer_idx + prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"] + if sended_layer_idx > prefilled_layer_idx: # computation must in next chunk + logger.info( + f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}" + ) + assert ( + current_prefilled_token_num_list[0] + < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"] + ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong" + end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx) + else: + end_layer_idx = prefilled_layer_idx + if sended_layer_idx == prefilled_layer_idx: # computation not in next layer + time.sleep(0.01) + for layer_idx in range(start_layer_idx, end_layer_idx + 1): + for i, (block_id_start, block_id_end) in enumerate(block_start_end_list): + engine_index = engine_indexes[i] + task = self.idx_cache_task_dict[engine_index] + req_id = task["request_id"] + if ( + block_id_start >= block_id_end + ): # no blocks need to transfer for this request in this chunk + task["sended_layer_id"] += 1 + assert task["sended_layer_id"] == layer_idx + if task["sended_layer_id"] == self.num_layers - 1: + task["sended_layer_id"] = -1 + continue + else: + current_transfer_protocol = task["transfer_protocol"] + if task["transfer_protocol"] == "rdma": + target_ip = task["ip"] + target_id = int(task["rdma_ports"][self.rank]) + if task["status"] == "error": + continue + status = self.messager[current_transfer_protocol].connect(target_ip, target_id) + if 
not status: + logger.error(f"connect to {target_ip}:{target_id} failed") + task["status"] = "connection error" + continue + elif task["transfer_protocol"] == "ipc": + target_ip = "0.0.0.0" + target_id = int(task["device_ids"][self.rank]) + + src_block_ids = task["src_block_ids"][block_id_start:block_id_end] + dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end] + src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu") + dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu") + + logger.info( + f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}" + ) + tic = time.time() + return_code = self.messager[current_transfer_protocol].write_cache( + target_ip, + target_id, + src_block_ids, + dest_block_ids, + layer_idx, + ) + if return_code != 0: + task["status"] = "write cache error" + logger.error( + f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}" + ) + tok = time.time() + cost_time = tok - tic + block_num = len(src_block_ids) + avg_time_per_block = cost_time * 1000 / block_num # ms + send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s + logger.debug( + f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}," + f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)}," + f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}" + ) + + task["sended_layer_id"] += 1 + assert task["sended_layer_id"] == layer_idx + if task["sended_layer_id"] == self.num_layers - 1: + self.idx_cache_task_dict[engine_index]["sended_block_num"] += ( + block_id_end - block_id_start + ) + if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]: + if task["status"] != "error": + task["status"] = "finished" + logger.info( + f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}" + ) + else: + task["sended_layer_id"] = -1 + if end_layer_idx == self.num_layers - 1: + with self.engine_cache_task_thread_lock: + for engine_idx in engine_indexes: + task = self.idx_cache_task_dict[engine_idx] + if task["status"] == "finished" or ("error" in task["status"]): + target_id = int(task["rdma_ports"][self.rank]) + if task["transfer_protocol"] == "ipc": + self.messager["ipc"].write_block_by_sync(target_id) + if self.rank == 0: + # to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok. 
+ self.engine_worker_queue.put_finished_req( + [(task["request_id"], task["status"])] + ) + logger.info(f"put write cache {task['request_id']}, status {task['status']}") + self.engine_cache_tasks[task["current_id"]] = dict() + del self.cache_info[task["request_id"]] + del self.idx_cache_task_dict[task["current_id"]] + break + except Exception as e: + logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}") + time.sleep(0.01) + + def consume_signals(self): + paddle.device.set_device("cpu") + kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32") + while True: + try: + get_output_kv_signal(kv_signal_data, self.rank_id, 0) # wait_flag + if not self.cache_info: + time.sleep(0.01) + continue + tasks_count = kv_signal_data[0] + if tasks_count == -1: + time.sleep(0.001) + continue + layer_id = kv_signal_data[1].numpy().tolist() + if layer_id == self.num_layers - 1: + logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}") + batch_engine_ids = [] + with self.engine_cache_task_thread_lock: + for bi in range(tasks_count): + engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist() + chuck_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist() + current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist() + self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id + self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = ( + chuck_token_offset + current_seq_len + ) + batch_engine_ids.append(engine_idx) + if layer_id == 0: + self.cache_prefilled_engine_ids_queue.put(batch_engine_ids) + except Exception as e: + logger.error(f"Consume signals get exception: {e}") + + def _handle_connect_task(self): + while True: + try: + task = self.engine_worker_queue.get_connect_rdma_task() + if task is None: + time.sleep(0.001) + continue + logger.info(f"_handle_connect_task recv task: {task}") + task_id = task["task_id"] + ip, rdma_port = task["ip"], task["rdma_port"] + status = self.messager["rdma"].connect(ip, rdma_port) + if not status: + response = {"task_id": task_id, "success": False} + else: + response = {"task_id": task_id, "success": True} + self.engine_worker_queue.put_connect_rdma_task_response(response) + except Exception as e: + logger.error(f"handle_connect_task has exception: {e}") + + +def main(): + device = args.device_id + rank = args.rank + paddle.set_device(f"gpu:{device}") + cache_type = args.cache_dtype + speculative_config = SpeculativeConfig(args.speculative_config) + num_extra_layers = speculative_config.num_extra_cache_layer + num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio) + gpu_cache_kvs = {} + gpu_cache_k_tensors = [] + gpu_cache_v_tensors = [] + + for i in range(args.num_layers + num_extra_layers): + num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks + + gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full( + shape=[ + num_gpu_blocks, + args.kv_num_head, + args.block_size, + args.head_dim, + ], + fill_value=0, + dtype=cache_type, + ) + gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"]) + gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full( + shape=[ + num_gpu_blocks, + args.kv_num_head, + args.block_size, + args.head_dim, + ], + fill_value=0, + dtype=cache_type, + ) + gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"]) + + set_data_ipc( + 
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"], + f"key_caches_{i}_rank{rank}.device{device}", + ) + set_data_ipc( + gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"], + f"value_caches_{i}_rank{rank}.device{device}", + ) + cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()]) + logger.info(f"device :{device}") + logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}") + logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}") + + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + cache_messager = CacheMessagerV1( + splitwise_role=args.splitwise_role, + transfer_protocol=args.protocol, + pod_ip=args.pod_ip, + engine_worker_queue_port=args.engine_worker_queue_port, + local_data_parallel_id=args.local_data_parallel_id, + gpu_cache_kvs=gpu_cache_kvs, + rank=rank, + nranks=args.mp_num, + num_layers=args.num_layers + num_extra_layers, + gpu_id=device, + rdma_port=args.rdma_port, + ) + else: + cache_messager = CacheMessager( + splitwise_role=args.splitwise_role, + transfer_protocol=args.protocol, + pod_ip=args.pod_ip, + engine_worker_queue_port=args.engine_worker_queue_port, + local_data_parallel_id=args.local_data_parallel_id, + gpu_cache_kvs=gpu_cache_kvs, + rank=rank, + nranks=args.mp_num, + num_layers=args.num_layers + num_extra_layers, + gpu_id=device, + rdma_port=args.rdma_port, + ) + + cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32) + cache_ready_signal = IPCSignal( + name="cache_ready_signal", + array=cache_ready_signal_data, + dtype=np.int32, + suffix=args.engine_pid, + create=False, + ) + cache_ready_signal.value[rank] = 1 + if args.splitwise_role == "mixed": + while True: + time.sleep(1) + cache_messager.prefill_layerwise_send_cache_thread() + + +if __name__ == "__main__": + + args = parse_args() + rank_id = args.rank + args.local_data_parallel_id * args.mp_num + logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log") + + logger.info("create cache messager...") + logger.info(f"{args}") + main() diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index cb793df44..fead9f8cb 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -29,7 +29,7 @@ from fastdeploy.config import SpeculativeConfig from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal from fastdeploy.model_executor.ops.gpu import ( cuda_host_alloc, - set_data_ipc, + share_external_data, swap_cache_all_layers, ) from fastdeploy.utils import get_logger @@ -139,40 +139,27 @@ class CacheTransferManager: self.num_cpu_blocks = args.num_cpu_blocks cache_type = args.cache_dtype + cache_shape = [ + args.num_gpu_blocks, + args.kv_num_head, + args.block_size, + args.head_dim, + ] + for i in range(args.num_layers + self.num_extra_layers): num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks + cache_shape[0] = num_gpu_blocks + key_name = f"key_caches_{i}_rank{rank}.device{device}" + value_name = f"value_caches_{i}_rank{rank}.device{device}" + key_cache = paddle.empty(shape=[], dtype=cache_type) + value_cache = paddle.empty(shape=[], dtype=cache_type) + key_cache = share_external_data(key_cache, key_name, cache_shape) + value_cache = share_external_data(value_cache, value_name, cache_shape) + self.gpu_cache_kvs[key_name] = key_cache + self.gpu_cache_kvs[value_name] = value_cache + 
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name]) + self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name]) - self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full( - shape=[ - num_gpu_blocks, - args.kv_num_head, - args.block_size, - args.head_dim, - ], - fill_value=0, - dtype=cache_type, - ) - self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"]) - self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full( - shape=[ - num_gpu_blocks, - args.kv_num_head, - args.block_size, - args.head_dim, - ], - fill_value=0, - dtype=cache_type, - ) - self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"]) - - set_data_ipc( - self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"], - f"key_caches_{i}_rank{rank}.device{device}", - ) - set_data_ipc( - self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"], - f"value_caches_{i}_rank{rank}.device{device}", - ) cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()]) logger.info(f"device :{self.device}") logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}") @@ -201,28 +188,6 @@ class CacheTransferManager: ) self.cache_ready_signal.value[self.rank] = 1 - paddle.set_device(f"gpu:{device}") - if args.enable_splitwise: - logger.debug("create cache messager...") - logger.info(f"{args}") - from fastdeploy.cache_manager.cache_messager import CacheMessager - - self.cache_messager = CacheMessager( - splitwise_role=args.splitwise_role, - transfer_protocol=args.protocol, - pod_ip=args.pod_ip, - engine_worker_queue_port=args.engine_worker_queue_port, - local_data_parallel_id=args.local_data_parallel_id, - gpu_cache_kvs=self.gpu_cache_kvs, - rank=self.rank, - nranks=args.mp_num, - num_layers=args.num_layers + self.num_extra_layers, - gpu_id=self.device, - rdma_port=args.rdma_port, - ) - logger.info("successfully create cache messager") - logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}") - cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) self.cache_task_broadcast_signal = IPCSignal( name="cache_task_broadcast_signal", @@ -443,5 +408,7 @@ def main(): if __name__ == "__main__": args = parse_args() - logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log") + rank_id = args.rank + args.local_data_parallel_id * args.mp_num + logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log") + paddle.set_device(f"gpu:{args.device_id}") main() diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index a0b110bde..8b3c8798e 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -150,6 +150,19 @@ class PrefixCacheManager: filename = "cache_transfer_manager.py" py_path = os.path.join(current_dir_path, filename) + cache_messager_processes = [] + cache_messager_processes = self.launch_cache_messager( + cache_config, + tensor_parallel_size, + device_ids, + pod_ip, + engine_worker_queue_port, + pid_suffix, + ) + if cache_messager_processes is None: + raise RuntimeError("Launch cache messager failed") + return [] + if ( hasattr(cache_config.model_cfg, "num_key_value_heads") and hasattr(cache_config.model_cfg, "num_key_value_heads") @@ -213,7 +226,76 @@ class PrefixCacheManager: if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0: logger.info("Enable 
hierarchical cache.") self._enable_cpu_cache() - return cache_manager_processes + all_cache_processes = cache_messager_processes + cache_manager_processes + return all_cache_processes + + def launch_cache_messager( + self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix + ): + """ + launch_cache_messager function used to initialize the cache messager. + """ + current_dir_path = os.path.split(os.path.abspath(__file__))[0] + filename = "cache_messager.py" + if ( + hasattr(cache_config.model_cfg, "num_key_value_heads") + and hasattr(cache_config.model_cfg, "num_key_value_heads") + and cache_config.model_cfg.num_key_value_heads is not None + and int(cache_config.model_cfg.num_key_value_heads) > 0 + ): + kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size + else: + kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size + + cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32) + self.cache_ready_signal = IPCSignal( + name="cache_ready_signal", + array=cache_ready_signal_data, + dtype=np.int32, + suffix=pid_suffix, + create=True, + ) + + py_path = os.path.join(current_dir_path, filename) + log_dir = envs.FD_LOG_DIR + cache_messager_processes = [] + for i in range(tensor_parallel_size): + launch_cmd = ( + "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7" + + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + + f" {sys.executable} {py_path}" + + f" --device_id {int(device_ids[i])}" + + f" --rank {i}" + + f" --splitwise_role {self.splitwise_role}" + + f" --num_layers {cache_config.model_cfg.num_hidden_layers}" + + f" --head_dim {cache_config.model_cfg.head_dim}" + + f" --kv_num_head {kv_num_head}" + + f" --mp_num {tensor_parallel_size}" + + f" --cache_dtype {cache_config.cache_dtype}" + + f" --pod_ip {pod_ip}" + + f" --cache_queue_port {cache_config.cache_queue_port}" + + f" --engine_worker_queue_port {engine_worker_queue_port}" + + f" --num_gpu_blocks {cache_config.total_block_num}" + + f" --block_size {cache_config.block_size}" + + f" --protocol {cache_config.cache_transfer_protocol}" + + f" --local_data_parallel_id {self.local_data_parallel_id}" + + f" --engine_pid {pid_suffix}" + + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + + f" --speculative_config '{self.speculative_config.to_json_string()}'" + + f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1" + ) + logger.info(f"Launch cache messager, command:{launch_cmd}") + cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) + logger.info("Waiting for cache ready...") + while np.sum(self.cache_ready_signal.value) != tensor_parallel_size: + time.sleep(1) + exit_code = cache_messager_processes[-1].poll() + if exit_code is None: + logger.info("Launch cache messager successful") + else: + logger.info("Launch cache messager failed, see launch_cache_messager.log for more information") + cache_messager_processes = None + return cache_messager_processes def update_cache_config(self, cache_config): """ diff --git a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py index 94abbb3b8..6a0c0ac36 100644 --- a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py +++ b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py @@ -61,18 +61,12 @@ class RDMACommManager: Connect to remote gpu and write cache. 
""" assert self.splitwise_role == "prefill", "only prefill can call this method" - addr = f"{ip}:{port!s}" - if addr in self.connected_rdma: - return True ret = self.messager.is_connected(ip, str(port)) if ret: - self.connected_rdma.add(addr) return True ret = self.messager.connect(ip, str(port)) logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}") - if ret == 0: - self.connected_rdma.add(addr) return ret == 0 def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx): diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 741daa512..42f86bc85 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1481,7 +1481,7 @@ class FDConfig: self.model_config.model_format = "torch" # TODO - self.max_prefill_batch = 3 + self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3")) if current_platform.is_xpu(): self.max_prefill_batch = 1 if self.model_config is not None and self.model_config.enable_mm: diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 1181dba7d..82dacc1c2 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -422,7 +422,7 @@ class EngineArgs: raise NotImplementedError("Only CUDA platform supports logprob.") if self.speculative_config is not None: envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 - if self.splitwise_role != "mixed": + if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma": envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 if not current_platform.is_cuda(): envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index b2f9c39e5..0456f9ade 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -46,7 +46,7 @@ from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.plugins.token_processor import load_token_processor_plugins from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector -from fastdeploy.utils import EngineError, envs, llm_logger +from fastdeploy.utils import EngineError, envs, get_logger, llm_logger try: TokenProcessor = load_token_processor_plugins() @@ -69,6 +69,13 @@ class EngineService: """ self.cfg = cfg + if self.cfg.parallel_config.enable_expert_parallel: + self.llm_logger = get_logger( + "fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log" + ) + else: + self.llm_logger = llm_logger + self.scheduler = cfg.scheduler_config.scheduler() if envs.ENABLE_V1_KVCACHE_SCHEDULER: @@ -79,10 +86,6 @@ class EngineService: cfg.scheduler_config.splitwise_role, cfg.parallel_config.local_data_parallel_id, ) - if cfg.scheduler_config.splitwise_role != "mixed": - raise NotImplementedError( - "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now." 
- ) else: self.resource_manager = ResourceManager( cfg.scheduler_config.max_num_seqs, @@ -135,12 +138,14 @@ class EngineService: self.insert_task_to_worker_thread.start() self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.run() + if self.cfg.scheduler_config.splitwise_role != "mixed": + self.split_mode_get_tasks() def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理 current_suffix = int( self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] ) - llm_logger.info(f"current_suffix: {current_suffix}") + self.llm_logger.info(f"current_suffix: {current_suffix}") exist_task_signal_data = np.zeros([1], dtype=np.int32) self.exist_task_signal = IPCSignal( name="exist_task_signal", @@ -201,7 +206,7 @@ class EngineService: ) if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"): - llm_logger.info(f"Starting engine worker queue server service at {address}") + self.llm_logger.info(f"Starting engine worker queue server service at {address}") self.engine_worker_queue_server = EngineWorkerQueue( address=address, is_server=True, @@ -225,7 +230,7 @@ class EngineService: client_id=-1, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) - llm_logger.info( + self.llm_logger.info( f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}" ) self.engine_worker_queue = EngineWorkerQueue( @@ -254,7 +259,17 @@ class EngineService: cur_task_idx = self.resource_manager.req_dict[task.request_id] del self.resource_manager.req_dict[task.request_id] cur_task = self.resource_manager.tasks_list[cur_task_idx] + if envs.FD_ENABLE_INTERNAL_ADAPTER: + if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue + self.resource_manager.stop_flags[cur_task_idx] = True + self.resource_manager.tasks_list[cur_task_idx] = None + self.resource_manager._recycle_block_tables(cur_task) + if task.request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[task.request_id] + self.llm_logger.warning(f"{task.request_id} need not decode after first token") + continue cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] + cur_task.num_cached_tokens = task.num_cached_tokens if ( self.cfg.speculative_config.method in ["mtp"] and self.cfg.scheduler_config.splitwise_role == "decode" @@ -267,13 +282,14 @@ class EngineService: if task.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[task.request_id] self.scheduler.put_results([task]) - llm_logger.warning( + self.llm_logger.warning( f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." 
) continue self.token_processor.tokens_counter[task.request_id] = 1 current_tasks.append(cur_task) - self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) + if current_tasks: + self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) return True self.resource_manager.check_and_free_block_tables() @@ -281,13 +297,34 @@ class EngineService: if not isinstance(tasks, list): tasks = [tasks] + need_delete_tasks = [] + for task in tasks: + if self.cfg.scheduler_config.splitwise_role != "mixed": + status, msg = self.split_connector.check_decode_allocated(task) + if not status: + self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg=msg, + ) + ] + ) + need_delete_tasks.append(task) + continue + for tmp_task in need_delete_tasks: + tasks.remove(tmp_task) + for item in tasks: item.schedule_start_time = time.time() available_batch = np.sum(self.resource_manager.stop_flags) if len(tasks) > available_batch: - llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") - llm_logger.error("The exceeded part will be ignored!") + self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") + self.llm_logger.error("The exceeded part will be ignored!") tasks = tasks[:available_batch] req_ids = [t.request_id for t in tasks] @@ -296,7 +333,7 @@ class EngineService: if not tasks: error_msg = f"The request required resources is exceed the limit, request id={req_ids}." - llm_logger.error(error_msg) + self.llm_logger.error(error_msg) raise EngineError(error_msg, error_code=500) return False @@ -314,7 +351,7 @@ class EngineService: self.split_connector.send_cache_infos(tasks, current_id) if not is_decode: - llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") + self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") for task in tasks: task.inference_start_time = time.time() if not is_prefill: @@ -473,7 +510,7 @@ class EngineService: Insert task to engine thread, monitor scheduler request queue. if the engine has resource, insert task to engine """ - current_id = -1 + current_id = 0 while getattr(self, "running", True): try: if self.resource_manager.available_batch() == 0: @@ -514,18 +551,21 @@ class EngineService: time.sleep(0.001) continue - current_id = (current_id + 1) % 100003 if self.cfg.scheduler_config.splitwise_role != "mixed": - llm_logger.info("Inserting splitwise tasks") + self.llm_logger.info("Inserting splitwise tasks") self.split_connector.send_splitwise_tasks(tasks, current_id) - self.insert_tasks(tasks, current_id) + insert_successful = self.insert_tasks(tasks, current_id) + if insert_successful: + current_id = current_id + 1 + else: + continue main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks)) except Exception as e: - err_msg = f"Error happened while insert task to engine: {e}, {traceback.format_exc()!s}." - llm_logger.error(err_msg) + err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}." 
+ self.llm_logger.error(err_msg) def _scheduler_task_to_worker_v1(self): """ @@ -535,40 +575,100 @@ class EngineService: is_fetching = False def _fetch_request(): - nonlocal is_fetching - is_fetching = True - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - if self.cfg.model_config.enable_mm: - available_blocks = self.resource_manager.available_block_num() - else: - available_blocks = self.cfg.cache_config.max_block_num_per_seq + try: + nonlocal is_fetching + is_fetching = True + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + if self.cfg.model_config.enable_mm: + available_blocks = self.resource_manager.available_block_num() + else: + available_blocks = self.cfg.cache_config.max_block_num_per_seq - tasks = self.scheduler.get_requests( - available_blocks=available_blocks, - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, - max_num_batched_tokens=self.cfg.max_model_len, - batch=num_prefill_batch, - ) - # Fetch requests and add them to the scheduling queue - for task in tasks: - self.resource_manager.add_request(task) - is_fetching = False + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, + max_num_batched_tokens=self.cfg.max_model_len, + batch=num_prefill_batch, + ) + if self.cfg.scheduler_config.splitwise_role != "mixed": + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + self.llm_logger.info(f"ask D resource for req_id: {task.request_id}") + self.split_connector.send_splitwise_tasks([task], task.idx) + need_delete_tasks = [] + for task in tasks: + if self.cfg.scheduler_config.splitwise_role != "mixed": + # assure fetch block ids from D + status, msg = self.split_connector.check_decode_allocated(task) + if not status: + self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg=msg, + ) + ] + ) + need_delete_tasks.append(task) + continue + for tmp_task in need_delete_tasks: + tasks.remove(tmp_task) + # release resource in P + self.resource_manager.prerelease_resource(task) + if self.cfg.scheduler_config.splitwise_role == "prefill": + # to send cache info to cache messager + if tasks: + self.split_connector.send_cache_infos(tasks, 0) + # ensure cache tasks has sent to cache_messager + need_check_req_ids = [task.request_id for task in tasks] + while need_check_req_ids: + req_ids = self.engine_worker_queue.get_finished_add_cache_task_req() + self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}") + if req_ids: + for req_id in req_ids: + assert req_id in need_check_req_ids + need_check_req_ids.remove(req_id) + else: + time.sleep(0.001) + # Fetch requests and add them to the scheduling queue + if tasks: + if self.cfg.scheduler_config.splitwise_role == "prefill": + self.resource_manager.add_request_in_p(tasks) + else: + for task in tasks: + self.resource_manager.add_request(task) + is_fetching = False + except Exception as e: + self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}") + is_fetching = False while self.running: try: if self.engine_worker_queue.num_tasks() > 0: time.sleep(0.001) continue - 
if ( - len(self.resource_manager.waiting) == 0 - and (not is_fetching) - and self.exist_prefill_task_signal.value[0] == 0 - ): - get_request_pool.submit(_fetch_request) + if self.cfg.scheduler_config.splitwise_role != "mixed": + if self.scheduler.get_unhandled_request_num() <= envs.FD_EP_MAX_PREFETCH_TASK_NUM and ( + not is_fetching + ): + get_request_pool.submit(_fetch_request) + + else: + if ( + len(self.resource_manager.waiting) == 0 + and (not is_fetching) + and self.exist_prefill_task_signal.value[0] == 0 + ): + get_request_pool.submit(_fetch_request) # 2. Schedule requests tasks = self.resource_manager.schedule() # 3. Send to engine @@ -579,8 +679,8 @@ class EngineService: time.sleep(0.005) except Exception as e: - err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) - llm_logger.error(err_msg) + err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) + self.llm_logger.error(err_msg) def start_zmq_service(self, api_server_pid=None): if api_server_pid is None: @@ -608,6 +708,9 @@ class EngineService: def _insert_zmq_task_to_scheduler(self): added_requests: Dict[str, int] = dict() + if envs.FD_ENABLE_INTERNAL_ADAPTER: + if self.cfg.scheduler_config.splitwise_role == "decode": + return while self.running: try: block = True if len(added_requests) == 0 else False @@ -616,7 +719,7 @@ class EngineService: else: err, data = self.recv_request_server.receive_pyobj_once(block) if err is not None: - llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}") + self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}") break request, insert_task = None, [] @@ -627,16 +730,16 @@ class EngineService: request = Request.from_dict(data) start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER) main_process_metrics.requests_number.inc() - llm_logger.debug(f"Receive request: {request}") + self.llm_logger.debug(f"Receive request: {request}") except Exception as e: - llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}") + self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}") err_msg = str(e) results.append((data["request_id"], err_msg)) if self.guided_decoding_checker is not None and err_msg is None: request, err_msg = self.guided_decoding_checker.schema_format(request) if err_msg is not None: - llm_logger.error(f"Receive request error: {err_msg}") + self.llm_logger.error(f"Receive request error: {err_msg}") results.append((request.request_id, err_msg)) if err_msg is None: @@ -670,7 +773,7 @@ class EngineService: # Send result by zmq directly self.send_response_server.send_response(request_id, [error_result]) except Exception as e: - llm_logger.error( + self.llm_logger.error( f"Error happened while receiving new request from zmq, details={e}, " f"traceback={traceback.format_exc()}" ) @@ -689,7 +792,7 @@ class EngineService: self.send_response_server.send_response(request_id, contents) except Exception as e: - llm_logger.error(f"Unexcepted error happened: {e}, {traceback.format_exc()!s}") + self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") def split_mode_get_tasks(self): """ @@ -702,13 +805,22 @@ class EngineService: processed_indices = [] for idx, task in enumerate(self.waiting_requests): - if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): - self.insert_tasks([task]) - llm_logger.info(f"Resource available, processing task {task.request_id}") - 
processed_indices.append(idx) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if self.resource_manager.preallocate_resource_in_d(task): + self.llm_logger.info(f"Resource available, processing task {task.request_id}") + self.split_connector.send_cache_infos([task], -1) + processed_indices.append(idx) + else: + self.llm_logger.debug(f"Still waiting for resources {task.request_id}") + break else: - llm_logger.debug(f"Still waiting for resources {task.request_id}") - break + if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): + self.insert_tasks([task]) + self.llm_logger.info(f"Resource available, processing task {task.request_id}") + processed_indices.append(idx) + else: + self.llm_logger.debug(f"Still waiting for resources {task.request_id}") + break for idx in sorted(processed_indices, reverse=True): self.waiting_requests.pop(idx) @@ -730,32 +842,79 @@ class EngineService: tasks = [tasks] for task in tasks: task.finished = False - self.insert_tasks(tasks, allocated=True) - - if self.cfg.innode_prefill_ports is not None: - self.scheduler.put_results(tasks) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + for task in tasks: + if envs.FD_ENABLE_INTERNAL_ADAPTER: + if ( + not task.outputs.token_ids + ): # first token is eos in Prefill, just recycle resource and continue + cur_task = self.resource_manager.requests[task.request_id] + self.resource_manager.stop_flags[cur_task.idx] = True + self.resource_manager.tasks_list[cur_task.idx] = None + self.resource_manager._free_blocks(cur_task) + if cur_task.request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[task.request_id] + self.llm_logger.warning( + f"{task.request_id} need not decode after first token" + ) + del self.resource_manager.requests[task.request_id] + del self.resource_manager.req_dict[task.request_id] + continue + if task.error_code != 200: + cur_task = self.resource_manager.requests[task.request_id] + self.resource_manager.stop_flags[cur_task.idx] = True + self.resource_manager.tasks_list[cur_task.idx] = None + self.resource_manager._free_blocks(cur_task) + if cur_task.request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[task.request_id] + self.scheduler.put_results([task]) + self.llm_logger.warning( + f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." 
+ ) + continue + self.resource_manager.insert_task_for_decoding(task) + else: + self.insert_tasks(tasks, allocated=True) + if self.cfg.innode_prefill_ports is not None: + self.scheduler.put_results(tasks) else: if len(self.waiting_requests): - llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}") + self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}") self.waiting_requests.extend(tasks) else: new_waiting = [] for task in tasks: - if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): - self.insert_tasks([task]) + can_allocate_resource = False + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if self.resource_manager.preallocate_resource_in_d(task): + self.split_connector.send_cache_infos([task], -1) + can_allocate_resource = True else: + if self.resource_manager.is_resource_sufficient( + task.prompt_token_ids_len + ): + self.insert_tasks([task]) + can_allocate_resource = True + if can_allocate_resource is False: + if not self.enable_decode_cache_task: + task.error_msg = "Not enough resources" new_waiting.append(task) if new_waiting: - self.waiting_requests.extend(new_waiting) - llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue") + if not self.enable_decode_cache_task: + self.split_connector.send_cache_infos(new_waiting, -1) + else: + self.waiting_requests.extend(new_waiting) + self.llm_logger.info( + f"Added {len(new_waiting)} tasks to waiting queue" + ) else: time.sleep(0.001) except Exception as e: - llm_logger.error(f"Error in main loop: {e}") + self.llm_logger.error(f"Error in main loop: {e}") time.sleep(0.1) threading.Thread(target=receiver_loop, daemon=True).start() diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index bbd9f34b3..84890b1e1 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -120,11 +120,10 @@ class LLMEngine: self.data_processor = self.input_processor.create_processor() self.engine.data_processor = self.data_processor + # Launch components: scheduler, cache_manager, expert_service et.al. + self.launch_components() self.engine.start() - if api_server_pid is not None: - llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}") - self.engine.start_zmq_service(api_server_pid) if self.do_profile == 0 and ( self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed" @@ -159,11 +158,14 @@ class LLMEngine: if self.do_profile: self._stop_profile() - # Launch components: scheduler, cache_manager, expert_service et.al. 
- self.launch_components() + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": self.launched_cache_manager_signal.value[0] = 1 + if api_server_pid is not None: + llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}") + self.engine.start_zmq_service(api_server_pid) + # Worker launched self.check_worker_initialize_status_func_thread.join() if not result_container["worker_is_alive"]: @@ -427,7 +429,10 @@ class LLMEngine: ) if self.cfg.scheduler_config.splitwise_role != "mixed": - variables["FLAGS_use_pd_disaggregation"] = 1 + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1 + else: + variables["FLAGS_use_pd_disaggregation"] = 1 # TODO dynamic load environment variable if self.cfg.scheduler_config.splitwise_role == "prefill": variables["FLAGS_fmt_write_cache_completed_signal"] = 1 @@ -498,6 +503,7 @@ class LLMEngine: f" --load_choices {self.cfg.load_config.load_choices}" f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'" f" --ips {ips}" + f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}" f" --runner {self.cfg.model_config.runner}" f" --convert {self.cfg.model_config.convert}" f" --override-pooler-config {self.cfg.model_config.override_pooler_config}" @@ -625,13 +631,11 @@ class LLMEngine: if self.cfg.scheduler_config.splitwise_role != "mixed": # 单机逻辑 self.engine.engine_worker_queue.available_prefill_instances.put(1) - self.engine.split_mode_get_tasks() - if self.cfg.scheduler_config.name == "splitwise": - self.splitwise_receive_thread = threading.Thread( - target=self.engine.split_connector.start_receiver, args=() - ) - self.splitwise_receive_thread.daemon = True - self.splitwise_receive_thread.start() + self.splitwise_receive_thread = threading.Thread( + target=self.engine.split_connector.start_receiver, args=() + ) + self.splitwise_receive_thread.daemon = True + self.splitwise_receive_thread.start() self.cfg.init_cache_info() @@ -640,6 +644,14 @@ class LLMEngine: disaggregate = self.cfg.disaggregate_info if self.cfg.scheduler_config.name == "splitwise": self.engine.scheduler.start(role, host_ip, disaggregate) + elif self.cfg.scheduler_config.name == "dp": + request_queues_for_dp_ipc = [] + result_queue_for_dp_ipc = multiprocessing.Queue() + for i in range(self.cfg.parallel_config.data_parallel_size): + request_queues_for_dp_ipc.append(multiprocessing.Queue()) + self.engine.scheduler.start( + self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc + ) if not envs.FD_ENABLE_MULTI_API_SERVER: if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: @@ -669,6 +681,9 @@ class LLMEngine: args=( self.cfg, i, + None, + request_queues_for_dp_ipc, + result_queue_for_dp_ipc, ), ) ) diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 65d86f47c..552c13b6a 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -27,6 +27,7 @@ import numpy as np from fastdeploy.engine.common_engine import EngineService from fastdeploy.inter_communicator import IPCSignal +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.utils import console_logger, envs, llm_logger @@ -69,8 +70,12 @@ class ExpertService: self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self._finalizer = weakref.finalize(self, 
self._exit_sub_services) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + self.internal_adapter = InternalAdapter(cfg=self.cfg, engine=self.engine, dp_rank=local_data_parallel_id) - def start(self, ipc_signal_suffix, local_data_parallel_id): + def start( + self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None + ): """ Initializes the engine and starts its sub-services. If `api_server_pid` is defined, will launch a thread @@ -80,6 +85,11 @@ class ExpertService: start_time = time.time() self.engine.start() + if self.cfg.scheduler_config.name == "dp": + self.cfg.init_cache_info() + assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None) + self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc) + if ipc_signal_suffix is not None: self.api_server_pid = ipc_signal_suffix self.engine.start_zmq_service(ipc_signal_suffix) @@ -88,8 +98,8 @@ class ExpertService: llm_logger.info(f"start expert service {local_data_parallel_id}") if self.cfg.scheduler_config.splitwise_role != "mixed": - self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix) - self.engine.split_mode_get_tasks() + ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id] + self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix_cache) if self.cfg.scheduler_config.name == "splitwise": self.cfg.init_cache_info() @@ -144,14 +154,18 @@ class ExpertService: self.zmq_server.close() -def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=None): +def start_data_parallel_service( + cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None +): """ Start expert service """ expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False) try: - expert_service.start(ipc_signal_suffix, local_data_parallel_id) + expert_service.start( + ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc + ) def deamon_thread(): while True: @@ -159,5 +173,6 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N t_deamon = threading.Thread(target=deamon_thread, daemon=True) t_deamon.start() + t_deamon.join() except Exception as e: llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}") diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 2a0def97a..3906cd29b 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -73,6 +73,7 @@ class Request: guided_json_object: Optional[bool] = None, enable_thinking: Optional[bool] = True, trace_carrier: dict = dict(), + dp_rank: Optional[int] = None, chat_template: Optional[str] = None, image_start: int = 0, video_start: int = 0, @@ -145,6 +146,8 @@ class Request: # extend block tables self.use_extend_tables = False self.extend_block_tables = [] + # dp + self.dp_rank = dp_rank @classmethod def from_dict(cls, d: dict): @@ -187,6 +190,7 @@ class Request: image_end=d.get("image_end", 0), video_end=d.get("video_end", 0), audio_end=d.get("audio_end", 0), + dp_rank=d.get("dp_rank", None), ) @property diff --git a/fastdeploy/engine/resource_manager.py b/fastdeploy/engine/resource_manager.py index 556547073..39f6a80e4 100644 --- a/fastdeploy/engine/resource_manager.py +++ b/fastdeploy/engine/resource_manager.py @@ -328,8 +328,8 @@ class ResourceManager: Delete cached data from the task's 
prompt token ids based on the cached length. """ if cached_len == len(task.prompt_token_ids): - task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :] - task.seq_lens_decoder = cached_len - 1 + task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :] + task.seq_lens_decoder = cached_len - self.cfg.block_size else: task.prompt_token_ids = task.prompt_token_ids[cached_len:] task.seq_lens_decoder = cached_len diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 254ba478a..eb3329536 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import copy import threading import time import traceback @@ -26,7 +27,7 @@ from typing import Union import numpy as np import paddle -from fastdeploy.engine.request import Request, RequestStatus, RequestType +from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.utils import llm_logger @@ -297,6 +298,11 @@ class ResourceManagerV1(ResourceManager): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding + if ( + self.config.scheduler_config.splitwise_role == "prefill" + ): # do not need to schedule for decoding + req_index += 1 + continue if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens request.num_computed_tokens = request.num_total_tokens - 1 if ( @@ -400,11 +406,12 @@ class ResourceManagerV1(ResourceManager): request.status = RequestStatus.RUNNING main_process_metrics.num_requests_waiting.dec(1) main_process_metrics.num_requests_running.inc(1) - allocated_position = self.get_available_position() - request.idx = allocated_position - self.tasks_list[allocated_position] = request - self.stop_flags[allocated_position] = False - self.req_dict[request.request_id] = allocated_position + if self.config.scheduler_config.splitwise_role == "mixed": + allocated_position = self.get_available_position() + request.idx = allocated_position + self.tasks_list[allocated_position] = request + self.stop_flags[allocated_position] = False + self.req_dict[request.request_id] = allocated_position else: if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) @@ -569,6 +576,127 @@ class ResourceManagerV1(ResourceManager): self.waiting.append(request) self.requests[request.request_id] = request + def prerelease_resource(self, request: Request): + """ + Release resource in P or D before finished due to unexpected error. + """ + with self.lock: + self.tasks_list[request.idx] = None + self.stop_flags[request.idx] = True + del self.requests[request.request_id] + del self.req_dict[request.request_id] + self._free_blocks(request) + + def add_request_in_p(self, requests: list[Request]): + with self.lock: + for request in requests: + request.inference_start_time = time.time() + request.schedule_start_time = time.time() + self.running.append(request) + + def preallocate_resource_in_p(self, request: Request): + """ + In P/D aggregated deployment, preallocate resource for P. 
+ If can allocate, allocate resources and return True + If can not, return False + """ + assert self.config.scheduler_config.splitwise_role == "prefill", "Only P instance can call this method" + with self.lock: + if self.available_batch() == 0: + return False + request.need_prefill_tokens = len(request.prompt_token_ids) + need_prealloc_prefill_blocks = ( + request.need_prefill_tokens + self.config.cache_config.block_size - 1 + ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num + if self.config.cache_config.enable_prefix_caching: + # Enable prefix caching + if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0: + if not self.cache_manager.can_allocate_gpu_blocks( + need_prealloc_prefill_blocks + ): # to prevent block allocation for matching in hierarchical cache and cause dead lock + return False + success = self.get_prefix_cached_blocks(request) + if not success: + self._free_blocks(request) + return False + # consider for mtp, plus enc_dec_block_num + need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0] + if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks): + request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks)) + allocated_position = self.get_available_position() + request.idx = allocated_position + self.tasks_list[request.idx] = request + self.stop_flags[request.idx] = False + self.requests[request.request_id] = request + self.req_dict[request.request_id] = allocated_position + return True + else: + self._free_blocks(request) + return False + + else: + if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): + request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks)) + request.num_computed_tokens = 0 + allocated_position = self.get_available_position() + request.idx = allocated_position + self.tasks_list[request.idx] = request + self.stop_flags[request.idx] = False + self.requests[request.request_id] = request + self.req_dict[request.request_id] = allocated_position + return True + + return False + + def preallocate_resource_in_d(self, request: Request): + """ + In P/D aggregated deployment, D should preallocate resource for P. 
+ If resources can be allocated, allocate them and return True. + If not, return False. + """ + assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method" + with self.lock: + if len(self.waiting) > 0: + return False + if self.available_batch() == 0: + return False + request.need_prefill_tokens = len(request.prompt_token_ids) + need_prealloc_prefill_blocks = ( + request.need_prefill_tokens + self.config.cache_config.block_size - 1 + ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num + if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): + request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks)) + request.num_computed_tokens = request.need_prefill_tokens + request.disaggregate_info["block_tables"] = request.block_tables + allocated_position = self.get_available_position() + request.idx = allocated_position + self.tasks_list[request.idx] = request + self.stop_flags[request.idx] = False + self.requests[request.request_id] = request + self.req_dict[request.request_id] = allocated_position + return True + return False + + def insert_task_for_decoding(self, request_output_in_p: RequestOutput): + """ + In P/D aggregated deployment, D should continue to decode after receiving the first token and cache from P. + """ + assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method" + with self.lock: + request = self.requests[request_output_in_p.request_id] + request.output_token_ids.append(request_output_in_p.outputs.token_ids[0]) + request.num_cached_tokens = request_output_in_p.num_cached_tokens + if ( + self.config.speculative_config.method in ["mtp"] + and self.config.scheduler_config.splitwise_role == "decode" + ): + request.draft_token_ids = copy.deepcopy(request_output_in_p.outputs.draft_token_ids) + # update request.need_prefill_tokens + request.need_prefill_tokens = len(request.prompt_token_ids) + 1 + request.inference_start_time = time.time() + request.schedule_start_time = time.time() + self.running.append(request) + def _free_blocks(self, request: Request): if self.config.cache_config.enable_prefix_caching: self.cache_manager.release_block_ids(request) @@ -620,5 +748,7 @@ self.tasks_list[request.idx] = None self.stop_flags[request.idx] = True del self.requests[req_id] + if req_id in self.req_dict: + del self.req_dict[req_id] except Exception as e: llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}") diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 32071f682..2e8704984 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -109,6 +109,12 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), # LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1 "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), + # Whether to enable cache task in decode node + "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"), + # Batched token timeout in EP + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), + # Max pre-fetch requests number in PD + "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), "FD_ENABLE_MODEL_LOAD_CACHE": lambda:
bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), } @@ -120,6 +126,14 @@ def __getattr__(name: str): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +def get_unique_name(self, name): + """ + Get unique name for config + """ + shm_uuid = os.getenv("SHM_UUID", "") + return name + f"_{shm_uuid}" + + def __setattr__(name: str, value: Any): assert name in environment_variables environment_variables[name] = lambda: value diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index da88265a2..e7609e02f 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -84,18 +84,28 @@ class EngineWorkerQueue: Value("i", 0) for _ in range(self.local_data_parallel_size) ] self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)] + self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] + self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)] + self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)] self.client_read_info_flag_init: List[List[int]] = [ [1] * self.num_client for _ in range(self.local_data_parallel_size) ] self.lock_info_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] + self.connect_task_lock_init: List[threading.Lock] = [ + threading.Lock() for _ in range(self.local_data_parallel_size) + ] self.finish_request_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] + self.finish_add_cache_task_barrier = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] + # Register shared objects with proxy types QueueManager.register( "get_tasks", @@ -117,6 +127,19 @@ class EngineWorkerQueue: callable=lambda idx: self.read_finish_flag_init[idx], proxytype=ValueProxy, ) + QueueManager.register( + "get_connect_task_lock", + callable=lambda idx: self.connect_task_lock_init[idx], + proxytype=AcquirerProxy, + ) + QueueManager.register( + "get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy + ) + QueueManager.register( + "get_connect_rdma_tasks_responses", + callable=lambda idx: self.connect_rdma_tasks_response_list[idx], + proxytype=ListProxy, + ) QueueManager.register( "get_connected_client_counter", callable=lambda idx: self.connected_client_counter_init[idx], @@ -128,6 +151,11 @@ class EngineWorkerQueue: callable=lambda idx: self.finished_req_queue[idx], ) + QueueManager.register( + "get_finish_add_cache_task_queue", + callable=lambda idx: self.finished_add_cache_task_queue[idx], + ) + QueueManager.register( "get_cache_infos", callable=lambda idx: self.cache_infos_init[idx], @@ -161,6 +189,10 @@ class EngineWorkerQueue: "get_finish_request_barrier", callable=lambda idx: self.finish_request_barrier[idx], ) + QueueManager.register( + "get_finish_add_cache_task_barrier", + callable=lambda idx: self.finish_add_cache_task_barrier[idx], + ) self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey) self.manager.start() else: @@ -174,12 +206,17 @@ class EngineWorkerQueue: QueueManager.register("get_read_finish_flag") QueueManager.register("get_connected_client_counter") QueueManager.register("get_finish_request_queue") + 
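The EngineWorkerQueue registrations above all follow the standard multiprocessing BaseManager pattern: the server process exposes plain Python objects (lists, locks, queues, barriers) under well-known type ids, and client processes re-register the same names and operate on them through proxies. The self-contained sketch below shows only that generic pattern; the manager classes, address, port and authkey are invented for illustration and are not FastDeploy APIs.

    import threading
    from multiprocessing.managers import BaseManager

    connect_rdma_tasks = []              # shared task list, accessed through a proxy
    connect_task_lock = threading.Lock() # guards the list across processes

    class ServerManager(BaseManager):
        pass

    class ClientManager(BaseManager):
        pass

    # Server side: expose the objects under well-known type ids.
    ServerManager.register("get_connect_rdma_tasks", callable=lambda: connect_rdma_tasks)
    ServerManager.register("get_connect_task_lock", callable=lambda: connect_task_lock)

    # Client side: register the same names without callables, then connect.
    ClientManager.register("get_connect_rdma_tasks")
    ClientManager.register("get_connect_task_lock")

    if __name__ == "__main__":
        server = ServerManager(address=("127.0.0.1", 50077), authkey=b"demo")
        server.start()

        client = ClientManager(address=("127.0.0.1", 50077), authkey=b"demo")
        client.connect()

        tasks = client.get_connect_rdma_tasks()  # list proxy
        lock = client.get_connect_task_lock()    # lock proxy

        # Same acquire / append / pop(0) discipline as put_connect_rdma_task and
        # get_connect_rdma_task in the hunk above.
        lock.acquire()
        try:
            tasks.append({"request_id": "req-0"})
            task = tasks.pop(0)
        finally:
            lock.release()
        print(task)
        server.shutdown()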
QueueManager.register("get_finish_add_cache_task_queue") QueueManager.register("get_cache_infos") QueueManager.register("get_client_read_info_flag") QueueManager.register("get_lock_info") QueueManager.register("get_disaggregate_requests") QueueManager.register("get_available_prefill_instances") QueueManager.register("get_finish_request_barrier") + QueueManager.register("get_finish_add_cache_task_barrier") + QueueManager.register("get_connect_rdma_tasks") + QueueManager.register("get_connect_rdma_tasks_responses") + QueueManager.register("get_connect_task_lock") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() @@ -199,7 +236,20 @@ class EngineWorkerQueue: self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id) self.available_prefill_instances = self.manager.get_available_prefill_instances() self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id) + self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier( + self.local_data_parallel_id + ) self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id) + self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue( + self.local_data_parallel_id + ) + # p/d互联 + self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) + self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses( + self.local_data_parallel_id + ) + self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id) + assert self.num_client == len(self.client_read_flag) if is_server: @@ -281,6 +331,44 @@ class EngineWorkerQueue: self.lock.release() return total_num + def put_connect_rdma_task(self, connect_rdma_task): + self.connect_task_lock.acquire() + self.connect_rdma_task_queue.append(connect_rdma_task) + self.connect_task_lock.release() + + def get_connect_rdma_task(self): + result = None + self.connect_task_lock.acquire() + if len(self.connect_rdma_task_queue) == 0: + self.connect_task_lock.release() + return result + try: + result = self.connect_rdma_task_queue.pop(0) + except Exception as e: + llm_logger.info(f"get_connect_rdma_task got exception: {e}") + finally: + self.connect_task_lock.release() + return result + + def put_connect_rdma_task_response(self, connect_rdma_task_response): + self.connect_task_lock.acquire() + self.connect_rdma_task_response_queue.append(connect_rdma_task_response) + self.connect_task_lock.release() + + def get_connect_rdma_task_response(self): + result = None + self.connect_task_lock.acquire() + if len(self.connect_rdma_task_response_queue) == 0: + self.connect_task_lock.release() + return result + try: + result = self.connect_rdma_task_response_queue.pop(0) + except Exception as e: + llm_logger.info(f"get_connect_rdma_task_response got exception: {e}") + finally: + self.connect_task_lock.release() + return result + def get_prefill_instances(self): """ check if the prefill queue is empty @@ -365,6 +453,29 @@ class EngineWorkerQueue: llm_logger.debug(f"get finished req: {ans}") return ans + def put_finished_add_cache_task_req(self, req_ids) -> None: + """ + Put finished request ID into the queue. + + Args: + req_ids: Request ID to be added to the queue + """ + self.finished_add_cache_task_queue.put(req_ids) + + def get_finished_add_cache_task_req(self) -> str: + """ + Get finished request ID from the queue. 
+ + Returns: + str: Finished request ID + """ + ans = [] + if self.finished_add_cache_task_queue.empty(): + return ans + ans = self.finished_add_cache_task_queue.get() + llm_logger.debug(f"get finished req: {ans}") + return ans + def disaggregate_queue_empty(self): """ Check if the disaggregated task queue is empty. diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 4ed894a91..61f3fca94 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -211,9 +211,8 @@ class DeepEPEngine: self.num_experts = num_experts self.num_local_experts = num_experts // ep_size self.async_finish = async_finish - from paddle.base.core import Config - self.ep_config = Config(24, 6, 256) + self.ep_config = None # Store phase and role for buffer management self._splitwise_role = splitwise_role diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 7b3bef5de..01cc699cb 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -76,6 +76,7 @@ else: update_inputs, step_reschedule, update_inputs_v1, + speculate_step_reschedule, ) @@ -413,12 +414,11 @@ def step_cuda( """ if speculative_config.method is not None: - if enable_prefix_caching: - speculate_step_system_cache( + if DISABLE_RECOVER: + speculate_step_reschedule( share_inputs["stop_flags"], share_inputs["seq_lens_this_time"], share_inputs["step_seq_lens_encoder"], - share_inputs["step_seq_lens_decoder"], share_inputs["seq_lens_encoder"], share_inputs["seq_lens_decoder"], share_inputs["block_tables"], @@ -444,64 +444,67 @@ def step_cuda( speculative_config.num_speculative_tokens, ) else: - speculate_step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - share_inputs["accept_num"], - block_size, - enc_dec_block_num, - speculative_config.num_speculative_tokens, - ) + if enable_prefix_caching: + speculate_step_system_cache( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["step_seq_lens_decoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + share_inputs["accept_num"], + block_size, + enc_dec_block_num, + 
speculative_config.num_speculative_tokens, + ) + else: + speculate_step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + share_inputs["accept_num"], + block_size, + enc_dec_block_num, + speculative_config.num_speculative_tokens, + ) else: - if enable_prefix_caching: - step_system_cache( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["step_seq_lens_decoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) - elif DISABLE_RECOVER: + if DISABLE_RECOVER: step_reschedule( share_inputs["stop_flags"], share_inputs["seq_lens_this_time"], @@ -529,32 +532,61 @@ def step_cuda( enc_dec_block_num, ) else: - step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) + if enable_prefix_caching: + step_system_cache( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["step_seq_lens_decoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + 
share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) + else: + step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) def rebuild_padding( diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 4f8a9c15d..596b32ab4 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -58,7 +58,6 @@ class TokenProcessor: self.split_connector = split_connector if envs.FD_USE_GET_SAVE_OUTPUT_V1: - llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}") self.zmq_server = ZmqIpcServer( name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL @@ -298,10 +297,15 @@ class TokenProcessor: try: is_blocking = True if self.speculative_decoding: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if ( + self.cfg.parallel_config.enable_expert_parallel + and self.cfg.parallel_config.data_parallel_size > 1 + ): + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) if self.output_tokens[0] == -2: continue - else: if self.use_logprobs: get_output_topk( @@ -370,14 +374,18 @@ class TokenProcessor: llm_logger.info(f"finished_task_id: {finished_task_id}") self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] if task_id in self.prefill_result_status: - self.split_connector.send_first_token(task.disaggregate_info, [result]) - self.resource_manager.stop_flags[index] = True - self.resource_manager.tasks_list[index] = None - self.resource_manager._recycle_block_tables(task) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager.finish_requests_async(task_id) + else: + self.resource_manager.stop_flags[index] = True + self.resource_manager.tasks_list[index] = None + self.resource_manager._recycle_block_tables(task) + if task_id in self.resource_manager.req_dict: + del self.resource_manager.req_dict[task_id] if self.prefill_result_status[task_id] != "finished": result.error_code = 400 result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}" - del self.resource_manager.req_dict[task_id] + self.split_connector.send_first_token(task.disaggregate_info, [result]) break else: time.sleep(0.002) @@ -388,6 +396,8 @@ class TokenProcessor: self.resource_manager.stop_flags[index] = True self.resource_manager.tasks_list[index] = None self.resource_manager._recycle_block_tables(task) + if task_id in self.resource_manager.req_dict: + del self.resource_manager.req_dict[task_id] task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list]) main_process_metrics.available_gpu_block_num.set( @@ -461,16 +471,22 @@ 
class TokenProcessor: task_id = task.request_id if self.cfg.speculative_config.method: - token_ids = tokens[ - 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS : 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS - + accept_num[i] - ].tolist() - if len(token_ids) == 0 or token_ids[-1] <= 0: - continue + if accept_num[i] == -3: + recovery_stop = True + if recovery_stop: + llm_logger.info(f"recovery stop signal found at task {task_id}") + token_ids = [RECOVERY_STOP_SIGNAL] + else: + token_ids = tokens[ + 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS : 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS + + accept_num[i] + ].tolist() + if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0): + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] @@ -527,7 +543,7 @@ class TokenProcessor: if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages - result.num_cached_tokens = task.num_cached_tokens + result.num_cached_tokens = task.num_cached_tokens is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill" @@ -537,7 +553,8 @@ class TokenProcessor: for token_id in token_ids: self.tokens_counter[task_id] += 1 if token_id != RECOVERY_STOP_SIGNAL: - result.outputs.token_ids.append(token_id) + if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids): + result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: result.outputs.logprob = float(scores[i, 0]) @@ -567,7 +584,11 @@ class TokenProcessor: self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, i, task, result, is_prefill) break - if not is_prefill or self.cfg.scheduler_config.name == "splitwise": + if ( + not is_prefill + or self.cfg.scheduler_config.name == "splitwise" + or self.cfg.scheduler_config.name == "dp" + ): batch_result.append(result) self.postprocess(batch_result) @@ -609,7 +630,7 @@ class TokenProcessor: self.cfg.speculative_config.num_speculative_tokens, ) - real_accept_num = [x for x in accept_num if x != 0] + real_accept_num = [x for x in accept_num if x > 0] num_accepted_tokens = sum([x - 1 for x in real_accept_num]) self.num_accepted_tokens += num_accepted_tokens num_emitted_tokens = sum(real_accept_num) diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py index e9d664261..e992933be 100644 --- a/fastdeploy/scheduler/config.py +++ b/fastdeploy/scheduler/config.py @@ -18,6 +18,7 @@ import redis from fastdeploy.utils import llm_logger +from .dp_scheduler import DPScheduler from .global_scheduler import GlobalScheduler from .local_scheduler import LocalScheduler from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig @@ -89,6 +90,54 @@ class LocalSchedulerConfig: llm_logger.info("=============================================================") +class DPLocalSchedulerConfig(LocalSchedulerConfig): + """ + Configuration class for DPLocalScheduler. + Attributes: + max_size: Maximum number of concurrent requests (-1 for unlimited) + ttl: Time-to-live in seconds for request expiration + """ + + def __init__( + self, + max_size: int = -1, + ttl: int = 900, + max_model_len: int = 8192, + enable_chunked_prefill: bool = False, + max_num_partial_prefills: int = 1, + max_long_partial_prefills: int = 1, + long_prefill_token_threshold: int = 0, + splitwise_role: str = "prefill", + **kwargs, + ): + """ + Initialize LocalScheduler configuration. 
+ Args: + max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled) + ttl: Time-to-live in seconds for request expiration (default 900s) + max_model_len: Maximum model context length in tokens + enable_chunked_prefill: Whether to enable chunked prefill processing + max_num_partial_prefills: Max partial prefill operations allowed + max_long_partial_prefills: Max long-running partial prefill ops + long_prefill_token_threshold: Token count threshold for long prefill + **kwargs: Additional unused arguments (for forward compatibility) + Note: + - If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len + - See LocalScheduler class for implementation details + """ + self.max_size = max_size + self.ttl = ttl + + self.max_model_len = max_model_len + self.enable_chunked_prefill = enable_chunked_prefill + self.max_num_partial_prefills = max_num_partial_prefills + self.max_long_partial_prefills = max_long_partial_prefills + self.long_prefill_token_threshold = long_prefill_token_threshold + if self.long_prefill_token_threshold == 0: + self.long_prefill_token_threshold = int(self.max_model_len * 0.04) + self.splitwise_role = splitwise_role + + class GlobalSchedulerConfig: """ Configuration class for GlobalScheduler (Redis-based). @@ -235,6 +284,9 @@ class SchedulerConfig: if self.name == "splitwise": self.config = SplitWiseSchedulerConfig(**args) + if self.name == "dp": + self.config = DPLocalSchedulerConfig(**args) + def check(self): """ Validate the configuration. @@ -242,7 +294,7 @@ class SchedulerConfig: Raises: Exception: If invalid scheduler type is specified """ - if self.name not in ["local", "global", "splitwise"]: + if self.name not in ["local", "global", "splitwise", "dp"]: raise Exception(f"Unknown scheduler type {self.name}") self.config.check() @@ -280,6 +332,17 @@ class SchedulerConfig: if self.name == "splitwise": return SplitWiseScheduler(self.config) + if self.name == "dp": + return DPScheduler( + max_size=self.config.max_size, + ttl=self.config.ttl, + enable_chunked_prefill=self.config.enable_chunked_prefill, + max_num_partial_prefills=self.config.max_num_partial_prefills, + max_long_partial_prefills=self.config.max_long_partial_prefills, + long_prefill_token_threshold=self.config.long_prefill_token_threshold, + splitwise_role=self.config.splitwise_role, + ) + return LocalScheduler( max_size=self.config.max_size, ttl=self.config.ttl, diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py new file mode 100644 index 000000000..288fb6aa7 --- /dev/null +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -0,0 +1,272 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
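With the scheduler config changes above, name "dp" now selects DPLocalSchedulerConfig and, at build time, a DPScheduler. One detail worth making concrete is the long_prefill_token_threshold default, which falls back to 4% of max_model_len when left at 0. A tiny standalone illustration, with example numbers only:

    # Illustration only: default derivation of long_prefill_token_threshold,
    # mirroring the logic in DPLocalSchedulerConfig above.
    max_model_len = 8192
    long_prefill_token_threshold = 0            # 0 means "derive automatically"
    if long_prefill_token_threshold == 0:
        long_prefill_token_threshold = int(max_model_len * 0.04)
    assert long_prefill_token_threshold == 327  # int(8192 * 0.04) == 327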
+""" + +import logging +import threading +import time +from multiprocessing import Queue +from typing import Dict, List, Optional + +from fastdeploy.engine.request import Request, RequestOutput +from fastdeploy.scheduler.data import ScheduledResponse +from fastdeploy.scheduler.local_scheduler import LocalScheduler +from fastdeploy.utils import envs, get_logger + + +class DPLocalScheduler(LocalScheduler): + def __init__( + self, + max_size: int, + ttl: int, + enable_chunked_prefill: bool, + max_num_partial_prefills: int, + max_long_partial_prefills: int, + long_prefill_token_threshold: int, + splitwise_role: str = "prefill", + ): + super().__init__( + max_size, + ttl, + enable_chunked_prefill, + max_num_partial_prefills, + max_long_partial_prefills, + long_prefill_token_threshold, + ) + self.splitwise_role = splitwise_role + self.scheduler_logger = logging + + def put_results(self, results: List[RequestOutput]): + """ + Add processing results back to the scheduler. + Args: + results: List of RequestOutput objects containing results + """ + responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results] + + finished_responses = [response.request_id for response in responses if response.finished] + if len(finished_responses) > 0: + self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}") + + with self.mutex: + for response in responses: + if response.request_id not in self.responses: + self.responses[response.request_id] = [response] + continue + self.responses[response.request_id].append(response) + self.responses_not_empty.notify_all() + + def _recycle(self, request_id: Optional[str] = None): + """ + Clean up expired or completed requests to free memory. + Args: + request_id: Optional specific request ID to remove. + If None, removes all expired requests. + """ + if request_id is not None: + self.requests.pop(request_id, None) + self.responses.pop(request_id, None) + if self.splitwise_role == "decode": + return + self.ids.pop(self.ids.index(request_id)) + self.ids_read_cursor -= 1 + return + + if self.max_size <= 0: + return + + if len(self.requests) <= self.max_size: + return + + now = time.time() + expired_ids = [] + for request_id in self.ids: + request = self.requests[request_id] + if now - request.schedule_time < self.ttl: + break + expired_ids.append(request.request_id) + + for i, expired_id in enumerate(expired_ids): + self.requests.pop(expired_id, None) + self.responses.pop(expired_id, None) + self.ids.pop(i) + + if len(expired_ids) > 0: + if len(expired_ids) - 1 >= self.ids_read_cursor: + self.ids_read_cursor = 0 + else: + self.ids_read_cursor -= len(expired_ids) + + def get_requests( + self, + available_blocks, + block_size, + reserved_output_blocks, + max_num_batched_tokens, + batch=1, + ) -> List[Request]: + """ + Retrieve requests from the scheduler based on available resources. 
+ + Args: + available_blocks: Number of available processing blocks + block_size: Size of each processing block + reserved_output_blocks: Blocks reserved for output + max_num_batched_tokens: Maximum tokens that can be batched + batch: Preferred batch size + + Returns: + List of Request objects ready for processing + """ + if available_blocks <= reserved_output_blocks or batch < 1: + self.scheduler_logger.debug( + f"Scheduler's resource are insufficient: available_blocks={available_blocks} " + f"reserved_output_blocks={reserved_output_blocks} batch={batch} " + f"max_num_batched_tokens={max_num_batched_tokens}" + ) + return [] + required_total_blocks = 0 + current_prefill_tokens = 0 + start_batch_time = time.time() + requests: List[Request] = [] + + with self.requests_not_empty: + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + while True: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], + 0.005, + ) + if batch_ids: + for request_id in batch_ids: + request = self.requests[request_id] + required_input_blocks = self.calc_required_blocks( + request.prompt_tokens_ids_len, block_size + ) + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks + if required_total_blocks > available_blocks: + break + + requests.append(request.raw) + self.ids_read_cursor += 1 + start_batch_time = time.time() + if current_prefill_tokens > max_num_batched_tokens: + break + if len(requests) >= batch: + break + if ( + (current_prefill_tokens > max_num_batched_tokens) + or (len(requests) >= batch) + or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT) + ): + break + else: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], + 0.005, + ) + if batch_ids: + for request_id in batch_ids: + request = self.requests[request_id] + requests.append(request.raw) + self.ids_read_cursor += 1 + + if batch_ids: + if len(batch_ids) > 0 and len(requests) == 0: + self.scheduler_logger.debug( + f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}" + ) + + if len(requests) > 0: + self.scheduler_logger.info( + f"Scheduler has pulled some request: {[request.request_id for request in requests]}" + ) + + return requests + + +class DPScheduler: + def __init__( + self, + max_size: int, + ttl: int, + enable_chunked_prefill: bool, + max_num_partial_prefills: int, + max_long_partial_prefills: int, + long_prefill_token_threshold: int, + splitwise_role: str = "prefill", + ): + self._scheduler = DPLocalScheduler( + max_size, + ttl, + enable_chunked_prefill, + max_num_partial_prefills, + max_long_partial_prefills, + long_prefill_token_threshold, + splitwise_role, + ) + + def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue): + self.dp_rank = dp_rank + self.request_queues = request_queues + self.result_queue = result_queue + self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log") + self._scheduler.scheduler_logger = self.scheduler_logger + threading.Thread(target=self._put_requests_to_local).start() + threading.Thread(target=self._get_response_from_local).start() + + def put_requests(self, requests: List[Dict]): + results = [] + for request in requests: + if not hasattr(request, "dp_rank"): + raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}") + self.request_queues[request.dp_rank].put(request) + 
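DPScheduler, defined above, routes each incoming request to the multiprocessing.Queue owned by its target data-parallel rank (request.dp_rank) and collects results through a single shared result queue. The stripped-down sketch below shows only that fan-out/fan-in shape; FakeRequest and the queue layout are invented for illustration and stand in for the real Request objects.

    from dataclasses import dataclass
    from multiprocessing import Queue

    @dataclass
    class FakeRequest:          # stand-in for fastdeploy.engine.request.Request
        request_id: str
        dp_rank: int

    data_parallel_size = 2
    request_queues = [Queue() for _ in range(data_parallel_size)]  # one queue per DP rank
    result_queue = Queue()                                         # shared fan-in queue

    def put_requests(requests):
        # Route each request to the queue owned by its target rank,
        # mirroring DPScheduler.put_requests above.
        for req in requests:
            request_queues[req.dp_rank].put(req)

    if __name__ == "__main__":
        put_requests([FakeRequest("req-0", 0), FakeRequest("req-1", 1)])
        # Each rank's scheduler thread would drain its own queue:
        print(request_queues[0].get().request_id)  # req-0
        print(request_queues[1].get().request_id)  # req-1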
results.append((request.request_id, None)) + return results + + def _put_requests_to_local(self): + while True: + request = self.request_queues[self.dp_rank].get() + self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}") + self._scheduler.put_requests([request]) + + def _get_response_from_local(self): + while True: + results = self._scheduler.get_results() + if len(results) == 0: + continue + self.result_queue.put(results) + + def get_requests( + self, + available_blocks, + block_size, + reserved_output_blocks, + max_num_batched_tokens, + batch=1, + ) -> List[Request]: + return self._scheduler.get_requests( + available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch + ) + + def get_unhandled_request_num(self): + return len(self._scheduler.requests) + + def put_results(self, results: List[RequestOutput]): + self._scheduler.put_results(results) + + def get_results(self) -> Dict[str, List[RequestOutput]]: + return self.result_queue.get() diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index b08215fc8..e87520d0d 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -28,8 +28,6 @@ from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.utils import get_logger -logger = get_logger("splitwise_connector", "splitwise_connector.log") - class SplitwiseConnector: """ @@ -46,12 +44,19 @@ class SplitwiseConnector: resource_manager (object): Resource manager object. """ self.cfg = cfg + if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: + self.logger = get_logger( + "splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log" + ) + else: + self.logger = get_logger("splitwise_connector", "splitwise_connector.log") self.engine_worker_queue = worker_queue self.resource_manager = resource_manager self.connect_innode_instances = {} self.temp_cache_info = dict() self.current_request_ids = dict() self.idx = self.cfg.parallel_config.local_data_parallel_id + self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" if self.cfg.cache_config.pd_comm_port is not None: self.zmq_ctx = zmq.Context() @@ -70,7 +75,7 @@ class SplitwiseConnector: self.router_socket.setsockopt(zmq.SNDHWM, 1000) self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1) self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}") - logger.info(f"bind {self.cfg.cache_config.pd_comm_port}") + self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}") self.poller = zmq.Poller() self.poller.register(self.router_socket, zmq.POLLIN) @@ -90,17 +95,17 @@ class SplitwiseConnector: if not socks: continue else: - logger.debug(f"receive {socks}") + self.logger.debug(f"receive {socks}") frames = self.router_socket.recv_multipart() - logger.debug(f"frames: {frames}") + self.logger.debug(f"frames: {frames}") message = frames[-1] self.io_executor.submit(self._process_message, message) time.sleep(0.001) else: time.sleep(5) except Exception as e: - logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}") + self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}") time.sleep(1) def _get_push_socket(self, addr): @@ -112,7 +117,7 @@ class SplitwiseConnector: return sock try: - logger.info(f"Establishing new connection to {addr}") + self.logger.info(f"Establishing new 
connection to {addr}") sock = self.zmq_ctx.socket(zmq.DEALER) # 设置连接参数 @@ -131,7 +136,7 @@ class SplitwiseConnector: return sock except zmq.ZMQError as e: - logger.error(f"Connection to {addr} failed: {e}") + self.logger.error(f"Connection to {addr} failed: {e}") raise ConnectionError(f"Failed to connect to {addr}") from e @@ -140,7 +145,7 @@ class SplitwiseConnector: return try: - logger.info(f"Sent {msg_type} to {addr}") + self.logger.info(f"Sent {msg_type} to {addr}") message = self._serialize_message(msg_type, payload) try: @@ -148,19 +153,19 @@ class SplitwiseConnector: sock = self._get_push_socket(addr) sock.send_multipart([b"", message]) - logger.info(f"Sent {msg_type} to {addr}") + self.logger.info(f"Sent {msg_type} to {addr}") except ConnectionError: - logger.warning(f"Connection to {addr} not established") + self.logger.warning(f"Connection to {addr} not established") except zmq.Again: - logger.warning(f"Send queue full for {addr}") + self.logger.warning(f"Send queue full for {addr}") except Exception as e: - logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}") + self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}") main_process_metrics.send_cache_failed_num.inc() self._close_connection(addr) except Exception as e: - logger.error(f"Message preparation failed: {e}") + self.logger.error(f"Message preparation failed: {e}") def _close_connection(self, addr): """ @@ -265,7 +270,7 @@ class SplitwiseConnector: f"{task.disaggregate_info['cache_info']['rdma']['ip']}:" + f"{task.disaggregate_info['cache_info']['rdma']['port']}" ) - logger.info(f"send splitwise tasks to port {addr} decode") + self.logger.info(f"send splitwise tasks to port {addr} decode") self.current_request_ids[task.request_id] = "init" decode_diagg = task.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] @@ -295,7 +300,7 @@ class SplitwiseConnector: self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks)) for task in tasks: task.disaggregate_info["cache_info"]["ipc"]["port"] = port - logger.info(f"send splitwise tasks to port {port} decode") + self.logger.info(f"send splitwise tasks to port {port} decode") current_port = port return current_port @@ -305,7 +310,7 @@ class SplitwiseConnector: """ if not isinstance(tasks_list, list): tasks_list = [tasks_list] - logger.info("send first token to port decode") + self.logger.info("send first token to port decode") if prefill_msg["transfer_protocol"] == "ipc": port = prefill_msg["cache_info"]["ipc"]["port"] if port not in self.connect_innode_instances: @@ -313,7 +318,7 @@ class SplitwiseConnector: self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list)) else: node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}" - logger.info(f"send first token to port {node} decode") + self.logger.info(f"send first token to port {node} decode") self._send_message(node, "decode", tasks_list) def create_connection(self, port): @@ -329,6 +334,26 @@ class SplitwiseConnector: client_id=0, ) + def check_decode_allocated(self, task): + start_time = time.time() + if task.disaggregate_info is None: + return True, "" + if self.enable_decode_cache_task: + return True, "" + if task.disaggregate_info["role"] != "prefill": + return True, "" + while self.current_request_ids[task.request_id] == "init": + time.sleep(0.001) + if time.time() - start_time > 30: + del self.current_request_ids[task.request_id] + return 
False, "timeout" + msg = self.current_request_ids[task.request_id] + del self.current_request_ids[task.request_id] + if msg == "finished": + return True, "" + self.logger.error(f"Receive_decode_allocated error: {msg}") + return False, msg + def send_cache_infos(self, tasks, current_id): """ Send cache information to specific port. @@ -345,7 +370,7 @@ class SplitwiseConnector: for i in range(len(tasks)): if tasks[i].disaggregate_info is None: continue - logger.info(f"{tasks[i].disaggregate_info}") + self.logger.info(f"{tasks[i].disaggregate_info}") if tasks[i].disaggregate_info["role"] == "decode": if tasks[i].disaggregate_info["transfer_protocol"] == "ipc": cache_info = { @@ -380,11 +405,19 @@ class SplitwiseConnector: addr = "prefill" if current_id == -1: current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"] - cache_info = { - "request_id": tasks[i].request_id, - "src_block_ids": tasks[i].block_tables, - "current_id": current_id, - } + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + cache_info = { + "request_id": tasks[i].request_id, + "src_block_ids": tasks[i].block_tables, + "current_id": tasks[i].idx, + "need_prefill_tokens": tasks[i].need_prefill_tokens, + } + else: + cache_info = { + "request_id": tasks[i].request_id, + "src_block_ids": tasks[i].block_tables, + "current_id": current_id, + } if addr not in temp_cache_info: temp_cache_info[addr] = [] @@ -396,7 +429,7 @@ class SplitwiseConnector: else: if len(temp_cache_info): for k, v in temp_cache_info.items(): - logger.info(f"{k} {v}") + self.logger.info(f"{k} {v}") if ":" in str(k): self._send_message(k, "cache_sync", v) else: @@ -427,7 +460,7 @@ class SplitwiseConnector: """ try: msg_type, payload = self._deserialize_message(message) - logger.info(f"{msg_type}") + self.logger.info(f"{msg_type}") if msg_type == "prefill": self._handle_prefill(payload) @@ -435,11 +468,16 @@ class SplitwiseConnector: self._handle_decode(payload) elif msg_type == "cache_sync": for task in payload: - del self.current_request_ids[task["request_id"]] - self.engine_worker_queue.put_cache_info(payload) + self.logger.info(f"cache_sync task: {task}") + current_status = task.get("error_msg", "finished") + self.current_request_ids[task["request_id"]] = current_status + if self.enable_decode_cache_task: + del self.current_request_ids[task["request_id"]] + if current_status == "finished": + self.engine_worker_queue.put_cache_info(payload) except Exception as e: - logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}") + self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}") def _handle_prefill(self, tasks): """ @@ -462,8 +500,12 @@ class SplitwiseConnector: index=task["outputs"]["index"], send_idx=0, token_ids=task["outputs"]["token_ids"], + draft_token_ids=task["outputs"]["draft_token_ids"], ), finished=True, + num_cached_tokens=task["num_cached_tokens"], + error_code=task["error_code"], + error_msg=task["error_msg"], ) ) self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks)) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index f188bb6a7..186dd58ea 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -16,6 +16,7 @@ import argparse import json +import os import time from typing import Tuple @@ -259,6 +260,7 @@ class PaddleDisWorkerProc: """Main event loop for Paddle Distributed Workers. TODO(gongshaotian): support remote calling of functions that control worker. 
""" + # Currently, only support single node self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8) req_ids = [] @@ -643,6 +645,12 @@ def parse_args(): help="Flag to specify dtype of lm_head as FP32", ) + parser.add_argument( + "--cache-transfer-protocol", + type=str, + default="ipc", + help="support protocol list, comma separated, default is ipc", + ) parser.add_argument( "--runner", type=str, @@ -762,8 +770,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: ): logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.") envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 - if args.splitwise_role != "mixed": - logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.") + if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma": envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 if not current_platform.is_cuda(): logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.") @@ -772,6 +779,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.") envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill": + os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1" + fd_config = FDConfig( model_config=model_config, parallel_config=parallel_config,