From 25498efcf37d330d7bb1ed63d58d3916c397e998 Mon Sep 17 00:00:00 2001 From: chenjian <1435317881@qq.com> Date: Mon, 3 Nov 2025 15:38:31 +0800 Subject: [PATCH] [Optimize] Support and robust for tpN for PD (#4595) * [Optimize] Support and robust for tpN for PD * fix * fix * support dpM tpN for cache messager * fix * fix token counter * fix bug for merge develop * fix bug * robust cache messager for v0 --- fastdeploy/cache_manager/cache_messager.py | 150 ++++---- fastdeploy/engine/common_engine.py | 80 ++-- fastdeploy/engine/expert_service.py | 5 + fastdeploy/envs.py | 2 + .../inter_communicator/engine_worker_queue.py | 352 +++++++++++++++--- fastdeploy/scheduler/dp_scheduler.py | 46 +-- .../splitwise/internal_adapter_utils.py | 2 +- fastdeploy/splitwise/splitwise_connector.py | 1 + fastdeploy/worker/worker_process.py | 11 +- 9 files changed, 452 insertions(+), 197 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 7ac2e699a..e6e6aa152 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -246,11 +246,10 @@ class CacheMessager: engine_recycled_count = 0 while True: - cache_info = self.engine_worker_queue.get_cache_info() - if cache_info: logger.debug(f"cache info {cache_info}") + self.engine_worker_queue.cache_info_barrier.wait() for info in cache_info: if info["request_id"] in self.cache_info: self.cache_info[info["request_id"]].update(info) @@ -295,9 +294,6 @@ class CacheMessager: continue if "layer_idx" not in item: item["layer_idx"] = 0 - if item["status"] == "error": - del self.cache_info[req_id] - continue if item["current_id"] > prefilled_step_idx: continue current_transfer_protocol = item["transfer_protocol"] @@ -307,11 +303,7 @@ class CacheMessager: status = self.messager[current_transfer_protocol].connect(target_ip, target_id) if not status: logger.error(f"connect to {target_ip}:{target_id} failed") - item["status"] = "error" - self.engine_worker_queue.finish_request_barrier.wait() - if self.rank == 0: - self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")]) - continue + item["status"] = "connect error" elif item["transfer_protocol"] == "ipc": target_ip = "0.0.0.0" target_id = int(item["device_ids"][self.rank]) @@ -321,48 +313,43 @@ class CacheMessager: current_layer_idx = self.num_layers else: current_layer_idx = prefilled_layer_idx + 1 - - for layer_idx in range(item["layer_idx"], current_layer_idx): - tic = time.time() - return_code = self.messager[current_transfer_protocol].write_cache( - target_ip, - target_id, - src_block_ids, - dest_block_ids, - layer_idx, - ) - if return_code != 0: - item["status"] = "error" - self.engine_worker_queue.finish_request_barrier.wait() - if self.rank == 0: - self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")]) - logger.error( - f"write cache failed, layer_idx: {layer_idx}, " - f"req_id: {item['request_id']}, dest_ip: {target_ip}" + if "error" not in item["status"]: + for layer_idx in range(item["layer_idx"], current_layer_idx): + tic = time.time() + return_code = self.messager[current_transfer_protocol].write_cache( + target_ip, + target_id, + src_block_ids, + dest_block_ids, + layer_idx, ) - break + if return_code != 0: + item["status"] = "write cache error" + logger.error( + f"write cache failed, layer_idx: {layer_idx}, " + f"req_id: {item['request_id']}, dest_ip: {target_ip}" + ) + break - tok = time.time() - cost_time = tok - tic - block_num = 
len(src_block_ids) - avg_time_per_block = cost_time * 1000 / block_num # ms - send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s - logger.debug( - f"finish write cache for a layer, {item['request_id']}, {layer_idx}" - f" {current_transfer_protocol}" - f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)}," - f"avg_time per block(ms): {round(avg_time_per_block, 5)}" - ) + tok = time.time() + cost_time = tok - tic + block_num = len(src_block_ids) + avg_time_per_block = cost_time * 1000 / block_num # ms + send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s + logger.debug( + f"finish write cache for a layer, {item['request_id']}, {layer_idx}" + f" {current_transfer_protocol}" + f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)}," + f"avg_time per block(ms): {round(avg_time_per_block, 5)}" + ) item["layer_idx"] = current_layer_idx if item["layer_idx"] == self.num_layers: if item["transfer_protocol"] == "ipc": self.messager["ipc"].write_block_by_sync(target_id) logger.info(f"finish write cache {item['request_id']}") - self.engine_worker_queue.finish_request_barrier.wait() - if self.rank == 0: - # to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok. - self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")]) - logger.info(f"put write cache {item['request_id']}") + self.engine_worker_queue.finish_send_cache_barrier.wait() + self.engine_worker_queue.put_finished_req([[item["request_id"], item["status"]]]) + logger.info(f"put write cache {item['request_id']}, status {item['status']}") del self.cache_info[req_id] self.last_layer_idx = prefilled_layer_idx @@ -376,14 +363,17 @@ class CacheMessager: if task is None: time.sleep(0.001) continue + else: + self.engine_worker_queue.connect_task_barrier.wait() logger.info(f"_handle_connect_task recv task: {task}") task_id = task["task_id"] - ip, rdma_port = task["ip"], task["rdma_port"] + ip, rdma_port = task["ip"], task["rdma_ports"][self.rank] status = self.messager["rdma"].connect(ip, rdma_port) if not status: response = {"task_id": task_id, "success": False} else: response = {"task_id": task_id, "success": True} + self.engine_worker_queue.connect_task_response_barrier.wait() self.engine_worker_queue.put_connect_rdma_task_response(response) except Exception as e: logger.error(f"handle_connect_task has exception: {e}") @@ -524,9 +514,9 @@ class CacheMessagerV1: while True: try: cache_info = self.engine_worker_queue.get_cache_info() - self.engine_worker_queue.finish_add_cache_task_barrier.wait() finished_add_cache_task_req_ids = [] if cache_info: + self.engine_worker_queue.cache_info_barrier.wait() for info in cache_info: if info["request_id"] in self.cache_info: self.cache_info[info["request_id"]].update(info) @@ -544,13 +534,16 @@ class CacheMessagerV1: current_info["sended_layer_id"] = -1 current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size current_info["status"] = "init" - logger.info(f"finish add cache task: {current_info}") + logger.info(f"Get cache info from P: finish add cache task: {current_info}") self.cache_info[info["request_id"]] = current_info self.idx_cache_task_dict[current_info["current_id"]] = current_info else: + logger.info(f"Get cache info from D: {info}") self.cache_info[info["request_id"]] = info - if self.rank == 0 and finished_add_cache_task_req_ids: + + if finished_add_cache_task_req_ids: 
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids) + self.engine_worker_queue.finish_add_cache_task_barrier.wait() else: time.sleep(0.001) except Exception as e: @@ -563,14 +556,16 @@ class CacheMessagerV1: """ while True: try: - engine_indexes = self.cache_prefilled_engine_ids_queue.get() - self.engine_worker_queue.finish_request_barrier.wait() + batch_engine_signals = self.cache_prefilled_engine_ids_queue.get() + self.engine_worker_queue.begin_send_cache_barrier.wait() block_start_end_list = [] current_prefilled_token_num_list = [] - for engine_index in engine_indexes: - assert engine_index in self.idx_cache_task_dict + for engine_index, current_step_prefilled_token_num in batch_engine_signals: + assert ( + engine_index in self.idx_cache_task_dict + ), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}" block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"] - prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"] + prefilled_token_num = current_step_prefilled_token_num if ( prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"] ): # all chunks have been prefilled @@ -580,17 +575,20 @@ class CacheMessagerV1: block_start_end_list.append((block_id_start, block_id_end)) current_prefilled_token_num_list.append(prefilled_token_num) while True: # from layer0 to last layer - sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"] + sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"] start_layer_idx = sended_layer_idx + 1 with self.engine_cache_task_thread_lock: # to check end_layer_idx - prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"] + prefilled_layer_idx = self.engine_cache_tasks[batch_engine_signals[0][0]][ + "prefilled_layer_idx" + ] if sended_layer_idx > prefilled_layer_idx: # computation must in next chunk logger.info( - f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}" + f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[batch_engine_signals[0][0]]['prefilled_token_num']}" ) + assert ( current_prefilled_token_num_list[0] - < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"] + < self.engine_cache_tasks[batch_engine_signals[0][0]]["prefilled_token_num"] ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong" end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx) else: @@ -599,7 +597,7 @@ class CacheMessagerV1: time.sleep(0.01) for layer_idx in range(start_layer_idx, end_layer_idx + 1): for i, (block_id_start, block_id_end) in enumerate(block_start_end_list): - engine_index = engine_indexes[i] + engine_index = batch_engine_signals[i][0] task = self.idx_cache_task_dict[engine_index] req_id = task["request_id"] if ( @@ -615,7 +613,7 @@ class CacheMessagerV1: if task["transfer_protocol"] == "rdma": target_ip = task["ip"] target_id = int(task["rdma_ports"][self.rank]) - if task["status"] == "error": + if "error" in task["status"]: continue status = self.messager[current_transfer_protocol].connect(target_ip, target_id) if not status: @@ -665,7 +663,7 @@ class CacheMessagerV1: block_id_end - block_id_start ) if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]: - if 
task["status"] != "error": + if "error" not in task["status"]: task["status"] = "finished" logger.info( f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}" @@ -674,18 +672,15 @@ class CacheMessagerV1: task["sended_layer_id"] = -1 if end_layer_idx == self.num_layers - 1: with self.engine_cache_task_thread_lock: - for engine_idx in engine_indexes: + for engine_idx, _ in batch_engine_signals: task = self.idx_cache_task_dict[engine_idx] if task["status"] == "finished" or ("error" in task["status"]): target_id = int(task["rdma_ports"][self.rank]) if task["transfer_protocol"] == "ipc": self.messager["ipc"].write_block_by_sync(target_id) - if self.rank == 0: - # to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok. - self.engine_worker_queue.put_finished_req( - [(task["request_id"], task["status"])] - ) - logger.info(f"put write cache {task['request_id']}, status {task['status']}") + self.engine_worker_queue.finish_send_cache_barrier.wait() + self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]]) + logger.info(f"put write cache {task['request_id']}, status {task['status']}") self.engine_cache_tasks[task["current_id"]] = dict() del self.cache_info[task["request_id"]] del self.idx_cache_task_dict[task["current_id"]] @@ -709,8 +704,9 @@ class CacheMessagerV1: continue layer_id = kv_signal_data[1].numpy().tolist() if layer_id == self.num_layers - 1: - logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}") - batch_engine_ids = [] + logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}") + batch_engine_signals = [] + # format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)] with self.engine_cache_task_thread_lock: for bi in range(tasks_count): engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist() @@ -720,27 +716,33 @@ class CacheMessagerV1: self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = ( chuck_token_offset + current_seq_len ) - batch_engine_ids.append(engine_idx) + batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len)) if layer_id == 0: - self.cache_prefilled_engine_ids_queue.put(batch_engine_ids) + logger.info( + f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue" + ) + self.cache_prefilled_engine_ids_queue.put(batch_engine_signals) except Exception as e: logger.error(f"Consume signals get exception: {e}") def _handle_connect_task(self): while True: try: - task = self.engine_worker_queue.get_connect_rdma_task() + task, _ = self.engine_worker_queue.get_connect_rdma_task() if task is None: time.sleep(0.001) continue + else: + self.engine_worker_queue.connect_task_barrier.wait() logger.info(f"_handle_connect_task recv task: {task}") task_id = task["task_id"] - ip, rdma_port = task["ip"], task["rdma_port"] + ip, rdma_port = task["ip"], task["rdma_ports"][self.rank] status = self.messager["rdma"].connect(ip, rdma_port) if not status: response = {"task_id": task_id, "success": False} else: response = {"task_id": task_id, "success": True} + self.engine_worker_queue.connect_task_response_barrier.wait() self.engine_worker_queue.put_connect_rdma_task_response(response) except Exception as e: logger.error(f"handle_connect_task has exception: {e}") diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 
11359fe91..c346d1c75 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -310,11 +310,7 @@ class EngineService: num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=0, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, - local_data_parallel_id=min( - self.cfg.worker_num_per_node // self.cfg.parallel_config.tensor_parallel_size * self.cfg.node_rank - + self.cfg.parallel_config.local_data_parallel_id, - self.cfg.parallel_config.data_parallel_size - 1, - ), + local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id, ) def insert_tasks(self, tasks, current_id=-1, allocated=False): @@ -656,39 +652,60 @@ class EngineService: self.cfg.max_prefill_batch, ) + if self.cfg.scheduler_config.splitwise_role != "mixed": + max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens + else: + max_num_batched_tokens = self.cfg.model_config.max_model_len + tasks = self.scheduler.get_requests( available_blocks=self.cfg.cache_config.max_block_num_per_seq, block_size=self.cfg.cache_config.block_size, reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, - max_num_batched_tokens=self.cfg.model_config.max_model_len, + max_num_batched_tokens=max_num_batched_tokens, batch=num_prefill_batch, ) if self.cfg.scheduler_config.splitwise_role != "mixed": - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - self.llm_logger.info(f"ask D resource for req_id: {task.request_id}") - self.split_connector.send_splitwise_tasks([task], task.idx) need_delete_tasks = [] - for task in tasks: - if self.cfg.scheduler_config.splitwise_role != "mixed": - # assure fetch block ids from D - status, msg = self.split_connector.check_decode_allocated(task) - if not status: - self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=500, - error_msg=msg, - ) - ] - ) - need_delete_tasks.append(task) - continue + if envs.FD_OFFLINE_PERF_TEST_FOR_PD: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + self.llm_logger.info(f"ask D resource for req_id: {task.request_id}") + while True: + self.split_connector.send_splitwise_tasks([task], task.idx) + status, msg = self.split_connector.check_decode_allocated(task) + if not status: + self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.") + time.sleep(0.05) + else: + break + else: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + self.llm_logger.info(f"ask D resource for req_id: {task.request_id}") + self.split_connector.send_splitwise_tasks([task], task.idx) + + for task in tasks: + if self.cfg.scheduler_config.splitwise_role != "mixed": + # assure fetch block ids from D + status, msg = self.split_connector.check_decode_allocated(task) + if not status: + self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg=msg, + ) + ] + ) + need_delete_tasks.append(task) + continue for tmp_task in need_delete_tasks: tasks.remove(tmp_task) # release resource in P @@ -930,7 +947,7 @@ class EngineService: for request_id, contents 
in results.items(): new_contents = [] for content in contents: - if isinstance(content, RequestOutput): + if isinstance(content, RequestOutput) and content.outputs is not None: decode_type = content.outputs.decode_type delta_text = "" if decode_type == 0: @@ -1035,6 +1052,7 @@ class EngineService: f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." ) continue + self.token_processor.tokens_counter[task.request_id] = 1 self.resource_manager.insert_task_for_decoding(task) else: diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 5a0da7bc6..174fcf9d2 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -27,6 +27,7 @@ import numpy as np from fastdeploy.engine.common_engine import EngineService from fastdeploy.inter_communicator import IPCSignal +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.utils import console_logger, envs, llm_logger @@ -99,6 +100,10 @@ class ExpertService: self.engine.start_zmq_service(ipc_signal_suffix) else: ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0] + if envs.FD_ENABLE_INTERNAL_ADAPTER: + self.internal_adapter = InternalAdapter( + cfg=self.cfg, engine=self.engine, dp_rank=self.cfg.parallel_config.local_data_parallel_id + ) llm_logger.info(f"start expert service {local_data_parallel_id}") diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 258631f58..1eb7af394 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -151,6 +151,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_MOE_QUANT_TYPE": lambda: os.getenv("FD_MOE_QUANT_TYPE", "w4a8"), "ENCODE_FEATURE_BOS_AK": lambda: os.getenv("ENCODE_FEATURE_BOS_AK"), "ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"), + # Enable offline perf test mode for PD disaggregation + "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")), } diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index 3c358154f..be4880e17 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -80,24 +80,79 @@ class EngineWorkerQueue: self.client_read_flag_init: List[List[int]] = [ [1] * self.num_client for _ in range(self.local_data_parallel_size) ] + self.lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)] self.read_finish_flag_init: List[Value] = [Value("i", 0) for _ in range(self.local_data_parallel_size)] self.connected_client_counter_init: List[Value] = [ Value("i", 0) for _ in range(self.local_data_parallel_size) ] - self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)] - self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)] + self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)] + self.finished_add_cache_task_list = [list() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)] self.client_read_info_flag_init: List[List[int]] = [ - [1] * self.num_client for _ in range(self.local_data_parallel_size) + [0] * self.num_client for _ in 
range(self.local_data_parallel_size) ] self.lock_info_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] + # PD disaggregation + # Locks self.connect_task_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) + ] # connect rdma task + self.connect_task_response_lock_init: List[threading.Lock] = [ + threading.Lock() for _ in range(self.local_data_parallel_size) + ] # connect rdma task response + self.finish_add_cache_task_lock_init: List[threading.Lock] = [ + threading.Lock() for _ in range(self.local_data_parallel_size) + ] # finish add cache task + self.finish_send_cache_lock_init: List[threading.Lock] = [ + threading.Lock() for _ in range(self.local_data_parallel_size) + ] # finish send cache + + # sync read status for TPs + self.client_get_connect_task_flag_init: List[List[int]] = [ + [0] * self.num_client for _ in range(self.local_data_parallel_size) + ] + self.client_get_connect_task_response_flag_init: List[List[int]] = [ + [0] * self.num_client for _ in range(self.local_data_parallel_size) + ] + self.client_get_finished_add_cache_task_flag_init: List[List[int]] = [ + [0] * self.num_client for _ in range(self.local_data_parallel_size) + ] + self.client_get_finish_send_cache_flag_init: List[List[int]] = [ + [0] * self.num_client for _ in range(self.local_data_parallel_size) + ] + self.can_put_next_connect_task_response_flag_init: List[Value] = [ + Value("i", 1) for _ in range(self.local_data_parallel_size) + ] + self.can_put_next_add_task_finished_flag_init: List[Value] = [ + Value("i", 1) for _ in range(self.local_data_parallel_size) + ] + self.can_put_next_send_cache_finished_flag_init: List[Value] = [ + Value("i", 1) for _ in range(self.local_data_parallel_size) + ] + + # barrier + self.get_connect_task_barrier = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] + self.get_connect_task_response_barrier = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] + self.finish_add_cache_task_barrier = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] + self.begin_send_cache_barrier = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] + self.finish_send_cache_barrier = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) + ] + self.get_cache_info_barrier = [ + threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] self.finish_request_barrier = [ @@ -107,10 +162,6 @@ class EngineWorkerQueue: threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] - self.finish_add_cache_task_barrier = [ - threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) - ] - # Register shared objects with proxy types QueueManager.register( "get_tasks", @@ -122,6 +173,26 @@ class EngineWorkerQueue: callable=lambda idx: self.client_read_flag_init[idx], proxytype=ListProxy, ) + QueueManager.register( + "get_client_get_connect_task_flag", + callable=lambda idx: self.client_get_connect_task_flag_init[idx], + proxytype=ListProxy, + ) + QueueManager.register( + "get_client_get_connect_task_response_flag", + callable=lambda idx: self.client_get_connect_task_response_flag_init[idx], + proxytype=ListProxy, + ) + QueueManager.register( + "get_client_get_finished_add_cache_task_flag_init", + callable=lambda idx: self.client_get_finished_add_cache_task_flag_init[idx], + 
proxytype=ListProxy, + ) + QueueManager.register( + "get_client_get_finish_send_cache_flag_init", + callable=lambda idx: self.client_get_finish_send_cache_flag_init[idx], + proxytype=ListProxy, + ) QueueManager.register( "get_lock", callable=lambda idx: self.lock_init[idx], @@ -132,11 +203,43 @@ class EngineWorkerQueue: callable=lambda idx: self.read_finish_flag_init[idx], proxytype=ValueProxy, ) + QueueManager.register( + "get_can_put_next_connect_task_response_flag", + callable=lambda idx: self.can_put_next_connect_task_response_flag_init[idx], + proxytype=ValueProxy, + ) + QueueManager.register( + "get_can_put_next_add_task_finished_flag", + callable=lambda idx: self.can_put_next_add_task_finished_flag_init[idx], + proxytype=ValueProxy, + ) + QueueManager.register( + "get_can_put_next_send_cache_finished_flag", + callable=lambda idx: self.can_put_next_send_cache_finished_flag_init[idx], + proxytype=ValueProxy, + ) + # PD disaggregation QueueManager.register( "get_connect_task_lock", callable=lambda idx: self.connect_task_lock_init[idx], proxytype=AcquirerProxy, ) + QueueManager.register( + "get_connect_task_response_lock", + callable=lambda idx: self.connect_task_response_lock_init[idx], + proxytype=AcquirerProxy, + ) + QueueManager.register( + "get_finish_add_cache_task_lock", + callable=lambda idx: self.finish_add_cache_task_lock_init[idx], + proxytype=AcquirerProxy, + ) + QueueManager.register( + "get_finish_send_cache_lock", + callable=lambda idx: self.finish_send_cache_lock_init[idx], + proxytype=AcquirerProxy, + ) + QueueManager.register( "get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy ) @@ -152,13 +255,13 @@ class EngineWorkerQueue: ) QueueManager.register( - "get_finish_request_queue", - callable=lambda idx: self.finished_req_queue[idx], + "get_finish_request_queue", callable=lambda idx: self.finished_req_list[idx], proxytype=ListProxy ) QueueManager.register( "get_finish_add_cache_task_queue", - callable=lambda idx: self.finished_add_cache_task_queue[idx], + callable=lambda idx: self.finished_add_cache_task_list[idx], + proxytype=ListProxy, ) QueueManager.register( @@ -194,6 +297,26 @@ class EngineWorkerQueue: "get_finish_request_barrier", callable=lambda idx: self.finish_request_barrier[idx], ) + QueueManager.register( + "get_connect_task_barrier", + callable=lambda idx: self.get_connect_task_barrier[idx], + ) + QueueManager.register( + "get_connect_task_response_barrier", + callable=lambda idx: self.get_connect_task_response_barrier[idx], + ) + QueueManager.register( + "get_begin_send_cache_barrier", + callable=lambda idx: self.begin_send_cache_barrier[idx], + ) + QueueManager.register( + "get_finish_send_cache_barrier", + callable=lambda idx: self.finish_send_cache_barrier[idx], + ) + QueueManager.register( + "get_cache_info_barrier", + callable=lambda idx: self.get_cache_info_barrier[idx], + ) QueueManager.register( "get_finish_add_cache_task_barrier", @@ -231,10 +354,25 @@ class EngineWorkerQueue: QueueManager.register("get_available_prefill_instances") QueueManager.register("get_finish_request_barrier") QueueManager.register("get_finish_add_cache_task_barrier") + QueueManager.register("get_connect_task_barrier") + QueueManager.register("get_connect_task_response_barrier") + QueueManager.register("get_finish_send_cache_barrier") + QueueManager.register("get_begin_send_cache_barrier") + QueueManager.register("get_cache_info_barrier") QueueManager.register("get_connect_rdma_tasks") + 
QueueManager.register("get_client_get_connect_task_flag") + QueueManager.register("get_client_get_connect_task_response_flag") + QueueManager.register("get_client_get_finished_add_cache_task_flag_init") + QueueManager.register("get_client_get_finish_send_cache_flag_init") QueueManager.register("get_connect_rdma_tasks_responses") QueueManager.register("get_connect_task_lock") + QueueManager.register("get_connect_task_response_lock") + QueueManager.register("get_finish_add_cache_task_lock") + QueueManager.register("get_finish_send_cache_lock") QueueManager.register("get_worker_process_tp_barrier") + QueueManager.register("get_can_put_next_connect_task_response_flag") + QueueManager.register("get_can_put_next_add_task_finished_flag") + QueueManager.register("get_can_put_next_send_cache_finished_flag") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() @@ -257,17 +395,50 @@ class EngineWorkerQueue: self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier( self.local_data_parallel_id ) + self.connect_task_barrier = self.manager.get_connect_task_barrier(self.local_data_parallel_id) + self.connect_task_response_barrier = self.manager.get_connect_task_response_barrier( + self.local_data_parallel_id + ) + self.finish_send_cache_barrier = self.manager.get_finish_send_cache_barrier(self.local_data_parallel_id) + self.cache_info_barrier = self.manager.get_cache_info_barrier(self.local_data_parallel_id) + self.begin_send_cache_barrier = self.manager.get_begin_send_cache_barrier(self.local_data_parallel_id) self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id) - self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id) - self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue( + self.finished_send_cache_list = self.manager.get_finish_request_queue(self.local_data_parallel_id) + self.finished_add_cache_task_list = self.manager.get_finish_add_cache_task_queue( self.local_data_parallel_id ) # p/d互联 - self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) - self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses( + self.connect_rdma_tasks = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) + self.client_get_connect_task_flag = self.manager.get_client_get_connect_task_flag( + self.local_data_parallel_id + ) + self.client_get_connect_task_response_flag = self.manager.get_client_get_connect_task_response_flag( + self.local_data_parallel_id + ) + self.client_get_finished_add_cache_task_flag = ( + self.manager.get_client_get_finished_add_cache_task_flag_init(self.local_data_parallel_id) + ) + self.client_get_finish_send_cache_flag = self.manager.get_client_get_finish_send_cache_flag_init( + self.local_data_parallel_id + ) + + self.connect_rdma_task_responses = self.manager.get_connect_rdma_tasks_responses( self.local_data_parallel_id ) self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id) + self.connect_task_response_lock = self.manager.get_connect_task_response_lock(self.local_data_parallel_id) + self.finish_add_cache_task_lock = self.manager.get_finish_add_cache_task_lock(self.local_data_parallel_id) + self.finish_send_cache_lock = self.manager.get_finish_send_cache_lock(self.local_data_parallel_id) + + self.can_put_next_add_task_finished_flag = self.manager.get_can_put_next_add_task_finished_flag( + 
self.local_data_parallel_id + ) + self.can_put_next_connect_task_response_flag = self.manager.get_can_put_next_connect_task_response_flag( + self.local_data_parallel_id + ) + self.can_put_next_send_cache_finished_flag = self.manager.get_can_put_next_send_cache_finished_flag( + self.local_data_parallel_id + ) assert self.num_client == len(self.client_read_flag) @@ -411,41 +582,61 @@ class EngineWorkerQueue: def put_connect_rdma_task(self, connect_rdma_task): self.connect_task_lock.acquire() - self.connect_rdma_task_queue.append(connect_rdma_task) + while sum(self.client_get_connect_task_flag) < self.num_client: + self.connect_task_lock.release() + time.sleep(0.001) + self.connect_task_lock.acquire() + + self.connect_rdma_tasks[:] = list() + self.client_get_connect_task_flag[:] = [0] * self.num_client + self.connect_rdma_tasks.append(connect_rdma_task) self.connect_task_lock.release() def get_connect_rdma_task(self): - result = None + connect_rdma_task = None self.connect_task_lock.acquire() - if len(self.connect_rdma_task_queue) == 0: - self.connect_task_lock.release() - return result - try: - result = self.connect_rdma_task_queue.pop(0) - except Exception as e: - llm_logger.info(f"get_connect_rdma_task got exception: {e}") - finally: - self.connect_task_lock.release() - return result + if len(self.connect_rdma_tasks) > 0: + connect_rdma_task = self.connect_rdma_tasks[0] + self.client_get_connect_task_flag[self.client_id] = 1 + all_client_read: bool = np.sum(self.client_get_connect_task_flag) == self.num_client + if all_client_read: + self.connect_rdma_tasks[:] = list() + self.connect_task_lock.release() + return connect_rdma_task, all_client_read def put_connect_rdma_task_response(self, connect_rdma_task_response): - self.connect_task_lock.acquire() - self.connect_rdma_task_response_queue.append(connect_rdma_task_response) - self.connect_task_lock.release() + self.connect_task_response_lock.acquire() + while not self.can_put_next_connect_task_response_flag.get(): + self.connect_task_response_lock.release() + time.sleep(0.001) + self.connect_task_response_lock.acquire() + self.connect_rdma_task_responses.append(connect_rdma_task_response) + self.client_get_connect_task_response_flag[self.client_id] = 1 + all_client_put: bool = np.sum(self.client_get_connect_task_response_flag) == self.num_client + if all_client_put: + self.can_put_next_connect_task_response_flag.set(0) + self.connect_task_response_lock.release() + return all_client_put def get_connect_rdma_task_response(self): - result = None - self.connect_task_lock.acquire() - if len(self.connect_rdma_task_response_queue) == 0: - self.connect_task_lock.release() - return result - try: - result = self.connect_rdma_task_response_queue.pop(0) - except Exception as e: - llm_logger.info(f"get_connect_rdma_task_response got exception: {e}") - finally: - self.connect_task_lock.release() - return result + task_response = None + self.connect_task_response_lock.acquire() + if len(self.connect_rdma_task_responses) == 0: + self.connect_task_response_lock.release() + return task_response + while sum(self.client_get_connect_task_response_flag) < self.num_client: + self.connect_task_response_lock.release() + time.sleep(0.001) + self.connect_task_response_lock.acquire() + if len(self.connect_rdma_task_responses) > 0: + task_response = self.connect_rdma_task_responses[0] + for tmp_task_response in self.connect_rdma_task_responses: + task_response["success"] = task_response["success"] and tmp_task_response["success"] + self.connect_rdma_task_responses[:] = 
list() + self.client_get_connect_task_response_flag[:] = [0] * self.num_client + self.can_put_next_connect_task_response_flag.set(1) + self.connect_task_response_lock.release() + return task_response def get_prefill_instances(self): """ @@ -508,14 +699,25 @@ class EngineWorkerQueue: self.lock_info.release() return total_num - def put_finished_req(self, req_ids) -> None: + def put_finished_req(self, send_cache_result) -> None: """ Put finished request ID into the queue. Args: req_ids: Request ID to be added to the queue """ - self.finished_req_queue.put(req_ids) + self.finish_send_cache_lock.acquire() + while not self.can_put_next_send_cache_finished_flag.get(): + self.finish_send_cache_lock.release() + time.sleep(0.001) + self.finish_send_cache_lock.acquire() + self.finished_send_cache_list.append(send_cache_result[0]) + self.client_get_finish_send_cache_flag[self.client_id] = 1 + all_client_put: bool = np.sum(self.client_get_finish_send_cache_flag) == self.num_client + if all_client_put: + self.can_put_next_send_cache_finished_flag.set(0) + self.finish_send_cache_lock.release() + return all_client_put def get_finished_req(self) -> str: """ @@ -524,12 +726,27 @@ class EngineWorkerQueue: Returns: str: Finished request ID """ - ans = [] - if self.finished_req_queue.empty(): - return ans - ans = self.finished_req_queue.get() - llm_logger.debug(f"get finished req: {ans}") - return ans + response = [] + self.finish_send_cache_lock.acquire() + if len(self.finished_send_cache_list) == 0: + self.finish_send_cache_lock.release() + return response + while sum(self.client_get_finish_send_cache_flag) < self.num_client: + self.finish_send_cache_lock.release() + time.sleep(0.001) + self.finish_send_cache_lock.acquire() + if len(self.finished_send_cache_list) > 0: + response = self.finished_send_cache_list[0] + for tmp_response in self.finished_send_cache_list: + if "error" in tmp_response[1]: + response[1] = tmp_response[1] + if response: + response = [response] + self.finished_send_cache_list[:] = list() + self.client_get_finish_send_cache_flag[:] = [0] * self.num_client + self.can_put_next_send_cache_finished_flag.set(1) + self.finish_send_cache_lock.release() + return response def put_finished_add_cache_task_req(self, req_ids) -> None: """ @@ -538,7 +755,18 @@ class EngineWorkerQueue: Args: req_ids: Request ID to be added to the queue """ - self.finished_add_cache_task_queue.put(req_ids) + self.finish_add_cache_task_lock.acquire() + while not self.can_put_next_add_task_finished_flag.get(): + self.finish_add_cache_task_lock.release() + time.sleep(0.001) + self.finish_add_cache_task_lock.acquire() + self.finished_add_cache_task_list.append(req_ids) + self.client_get_finished_add_cache_task_flag[self.client_id] = 1 + all_client_put: bool = np.sum(self.client_get_finished_add_cache_task_flag) == self.num_client + if all_client_put: + self.can_put_next_add_task_finished_flag.set(0) + self.finish_add_cache_task_lock.release() + return all_client_put def get_finished_add_cache_task_req(self) -> str: """ @@ -547,12 +775,24 @@ class EngineWorkerQueue: Returns: str: Finished request ID """ - ans = [] - if self.finished_add_cache_task_queue.empty(): - return ans - ans = self.finished_add_cache_task_queue.get() - llm_logger.debug(f"get finished req: {ans}") - return ans + response = [] + self.finish_add_cache_task_lock.acquire() + if len(self.finished_add_cache_task_list) == 0: + self.finish_add_cache_task_lock.release() + return response + while sum(self.client_get_finished_add_cache_task_flag) < 
self.num_client: + self.finish_add_cache_task_lock.release() + time.sleep(0.001) + self.finish_add_cache_task_lock.acquire() + if len(self.finished_add_cache_task_list) > 0: + response = self.finished_add_cache_task_list[0] + for tmp_response in self.finished_add_cache_task_list: + assert tmp_response == response + self.finished_add_cache_task_list[:] = list() + self.client_get_finished_add_cache_task_flag[:] = [0] * self.num_client + self.can_put_next_add_task_finished_flag.set(1) + self.finish_add_cache_task_lock.release() + return response def disaggregate_queue_empty(self): """ diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py index 0a1c9b0d6..18227ce04 100644 --- a/fastdeploy/scheduler/dp_scheduler.py +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -143,38 +143,7 @@ class DPLocalScheduler(LocalScheduler): requests: List[Request] = [] with self.requests_not_empty: - if not envs.ENABLE_V1_KVCACHE_SCHEDULER: - while True: - batch_ids = self.requests_not_empty.wait_for( - lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], - 0.005, - ) - if batch_ids: - for request_id in batch_ids: - request = self.requests[request_id] - required_input_blocks = self.calc_required_blocks( - request.prompt_tokens_ids_len, block_size - ) - current_prefill_tokens += request.prompt_tokens_ids_len - required_total_blocks += required_input_blocks + reserved_output_blocks - if required_total_blocks > available_blocks: - break - - requests.append(request.raw) - self.ids_read_cursor += 1 - start_batch_time = time.time() - if current_prefill_tokens > max_num_batched_tokens: - break - if len(requests) >= batch: - break - if ( - (current_prefill_tokens > max_num_batched_tokens) - or (len(requests) >= batch) - or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT) - ): - break - else: - required_total_blocks = 0 + while True: batch_ids = self.requests_not_empty.wait_for( lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], 0.005, @@ -183,11 +152,24 @@ class DPLocalScheduler(LocalScheduler): for request_id in batch_ids: request = self.requests[request_id] required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + current_prefill_tokens += request.prompt_tokens_ids_len required_total_blocks += required_input_blocks + reserved_output_blocks if required_total_blocks > available_blocks: break + requests.append(request.raw) self.ids_read_cursor += 1 + start_batch_time = time.time() + if current_prefill_tokens > max_num_batched_tokens: + break + if len(requests) >= batch: + break + if ( + (current_prefill_tokens > max_num_batched_tokens) + or (len(requests) >= batch) + or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT) + ): + break if batch_ids: if len(batch_ids) > 0 and len(requests) == 0: diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index eabae716d..30aa74d7d 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -78,7 +78,7 @@ class InternalAdapter: if task is None: time.sleep(0.001) continue - logger.info(f"Recieve control task: {task}") + logger.info(f"dprank {self.dp_rank} Recieve control task: {task}") task_id_str = task["task_id"] if task["cmd"] == "get_payload": payload_info = self._get_current_server_info() diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 01d1c50c0..ea4022390 100644 --- 
a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -275,6 +275,7 @@ class SplitwiseConnector: decode_diagg = task.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id + task.disaggregate_info["role"] = "decode" self._send_message(addr, "prefill", [task]) task.disaggregate_info["cache_info"] = decode_diagg task.disaggregate_info["role"] = "prefill" diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 4f492db96..8cedfea37 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -177,7 +177,7 @@ class PaddleDisWorkerProc: self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER: launched_expert_service_signal_data = np.zeros( - shape=[min(self.parallel_config.data_parallel_size, self.max_chips_per_node)], dtype=np.int32 + shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32 ) self.launched_expert_service_signal = IPCSignal( name="launched_expert_service_signal", @@ -186,7 +186,12 @@ class PaddleDisWorkerProc: suffix=self.parallel_config.engine_pid, create=False, ) - while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0: + while ( + self.launched_expert_service_signal.value[ + self.parallel_config.local_data_parallel_id % self.max_chips_per_node + ] + == 0 + ): pass # init worker_ready_signal @@ -568,7 +573,7 @@ class PaddleDisWorkerProc: is_server=False, num_client=self.parallel_config.tensor_parallel_size, client_id=self.parallel_config.tensor_parallel_rank, - local_data_parallel_id=self.parallel_config.data_parallel_rank, + local_data_parallel_id=self.parallel_config.local_data_parallel_id, ) def load_model(self) -> None:
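The new put_finished_req/get_finished_req pair replaces the per-DP multiprocessing Queue with a manager-backed list guarded by finish_send_cache_lock, per-TP-rank flags (client_get_finish_send_cache_flag) and a can_put_next_send_cache_finished_flag gate: every tensor-parallel rank appends its own send-cache status, the gate closes once all ranks have reported, and the consumer merges the entries so that any rank reporting an error marks the whole request as failed. Below is a minimal single-process sketch of that handshake, assuming plain threading primitives in place of the QueueManager ListProxy/ValueProxy objects; the class name TpSyncedResultQueue and its exact signatures are illustrative only and not part of the patch.

import threading
import time


class TpSyncedResultQueue:
    """Every TP rank puts one (req_id, status) entry per request; the consumer
    only sees a single merged result once all ranks have reported."""

    def __init__(self, num_client: int):
        self.num_client = num_client
        self.lock = threading.Lock()               # stands in for finish_send_cache_lock
        self.results = []                          # stands in for finished_send_cache_list
        self.put_flags = [0] * num_client          # stands in for client_get_finish_send_cache_flag
        self.can_put_next = True                   # stands in for can_put_next_send_cache_finished_flag

    def put_finished_req(self, client_id: int, req_id: str, status: str) -> None:
        # Spin until the previous round has been consumed, then record this rank's status.
        while True:
            with self.lock:
                if self.can_put_next:
                    self.results.append([req_id, status])
                    self.put_flags[client_id] = 1
                    if sum(self.put_flags) == self.num_client:
                        self.can_put_next = False  # last rank closes the gate
                    return
            time.sleep(0.001)

    def get_finished_req(self) -> list:
        # Returns [] until every TP rank has put its status for the current request.
        with self.lock:
            if sum(self.put_flags) < self.num_client:
                return []
            merged = self.results[0]
            for entry in self.results:
                if "error" in entry[1]:            # any failing rank fails the whole request
                    merged = entry
            self.results = []
            self.put_flags = [0] * self.num_client
            self.can_put_next = True               # reopen the gate for the next round
            return [merged]


if __name__ == "__main__":
    q = TpSyncedResultQueue(num_client=2)
    threading.Thread(target=q.put_finished_req, args=(0, "req-1", "finished")).start()
    threading.Thread(target=q.put_finished_req, args=(1, "req-1", "write cache error")).start()
    out = []
    while not out:
        out = q.get_finished_req()
        time.sleep(0.001)
    print(out)  # [['req-1', 'write cache error']] -- the error status wins the merge

With one rank reporting "finished" and the other "write cache error" for the same request, the consumer receives a single entry carrying the error status, which mirrors the merge rule in the patched get_finished_req and the "all ranks report, error dominates" behavior this PR adds for tpN PD disaggregation.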