diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py
index 4502860d0..2602ff0e9 100644
--- a/fastdeploy/cache_manager/cache_messager.py
+++ b/fastdeploy/cache_manager/cache_messager.py
@@ -564,14 +564,14 @@ class CacheMessagerV1:
         """
         while True:
             try:
-                engine_indexes = self.cache_prefilled_engine_ids_queue.get()
+                batch_engine_signals = self.cache_prefilled_engine_ids_queue.get()
                 self.engine_worker_queue.finish_request_barrier.wait()
                 block_start_end_list = []
                 current_prefilled_token_num_list = []
-                for engine_index in engine_indexes:
+                for engine_index, current_step_prefilled_token_num in batch_engine_signals:
                     assert engine_index in self.idx_cache_task_dict
                     block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
-                    prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
+                    prefilled_token_num = current_step_prefilled_token_num
                     if (
                         prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
                     ):  # all chunks have been prefilled
@@ -581,17 +581,19 @@ class CacheMessagerV1:
                     block_start_end_list.append((block_id_start, block_id_end))
                     current_prefilled_token_num_list.append(prefilled_token_num)
                 while True:  # from layer0 to last layer
-                    sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
+                    sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
                     start_layer_idx = sended_layer_idx + 1
                     with self.engine_cache_task_thread_lock:  # to check end_layer_idx
-                        prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
+                        prefilled_layer_idx = self.engine_cache_tasks[batch_engine_signals[0][0]][
+                            "prefilled_layer_idx"
+                        ]
                         if sended_layer_idx > prefilled_layer_idx:  # computation must in next chunk
                             logger.info(
-                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
+                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[batch_engine_signals[0][0]]['prefilled_token_num']}"
                             )
                             assert (
                                 current_prefilled_token_num_list[0]
-                                < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
+                                < self.engine_cache_tasks[batch_engine_signals[0][0]]["prefilled_token_num"]
                             ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
                             end_layer_idx = self.num_layers - 1  # [start_layer_idx, end_layer_idx)
                         else:
@@ -600,7 +602,7 @@
                         time.sleep(0.01)
                 for layer_idx in range(start_layer_idx, end_layer_idx + 1):
                     for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
-                        engine_index = engine_indexes[i]
+                        engine_index = batch_engine_signals[i][0]
                         task = self.idx_cache_task_dict[engine_index]
                         req_id = task["request_id"]
                         if (
@@ -675,7 +677,7 @@ class CacheMessagerV1:
                             task["sended_layer_id"] = -1
                 if end_layer_idx == self.num_layers - 1:
                     with self.engine_cache_task_thread_lock:
-                        for engine_idx in engine_indexes:
+                        for engine_idx, _ in batch_engine_signals:
                             task = self.idx_cache_task_dict[engine_idx]
                             if task["status"] == "finished" or ("error" in task["status"]):
                                 target_id = int(task["rdma_ports"][self.rank])
@@ -711,7 +713,8 @@ class CacheMessagerV1:
                     layer_id = kv_signal_data[1].numpy().tolist()
                     if layer_id == self.num_layers - 1:
                         logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
-                    batch_engine_ids = []
+                    batch_engine_signals = []
+                    # format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
                     with self.engine_cache_task_thread_lock:
                         for bi in range(tasks_count):
                             engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
@@ -721,9 +724,9 @@ class CacheMessagerV1:
                             self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
                                 chuck_token_offset + current_seq_len
                             )
-                            batch_engine_ids.append(engine_idx)
+                            batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
                         if layer_id == 0:
-                            self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
+                            self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
             except Exception as e:
                 logger.error(f"Consume signals get exception: {e}")
 
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index c9a5db2ab..8c22a761e 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -596,31 +596,47 @@ class EngineService:
                     batch=num_prefill_batch,
                 )
                 if self.cfg.scheduler_config.splitwise_role != "mixed":
-                    for task in tasks:
-                        # assure can allocate block ids in P
-                        while not self.resource_manager.preallocate_resource_in_p(task):
-                            time.sleep(0.005)
-                        self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
-                        self.split_connector.send_splitwise_tasks([task], task.idx)
                     need_delete_tasks = []
-                    for task in tasks:
-                        if self.cfg.scheduler_config.splitwise_role != "mixed":
-                            # assure fetch block ids from D
-                            status, msg = self.split_connector.check_decode_allocated(task)
-                            if not status:
-                                self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
-                                self.scheduler.put_results(
-                                    [
-                                        RequestOutput(
-                                            request_id=task.request_id,
-                                            finished=True,
-                                            error_code=500,
-                                            error_msg=msg,
-                                        )
-                                    ]
-                                )
-                                need_delete_tasks.append(task)
-                                continue
+                    if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
+                        for task in tasks:
+                            # assure can allocate block ids in P
+                            while not self.resource_manager.preallocate_resource_in_p(task):
+                                time.sleep(0.005)
+                            self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
+                            while True:
+                                self.split_connector.send_splitwise_tasks([task], task.idx)
+                                status, msg = self.split_connector.check_decode_allocated(task)
+                                if not status:
+                                    self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.")
+                                    time.sleep(0.05)
+                                else:
+                                    break
+                    else:
+                        for task in tasks:
+                            # assure can allocate block ids in P
+                            while not self.resource_manager.preallocate_resource_in_p(task):
+                                time.sleep(0.005)
+                            self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
+                            self.split_connector.send_splitwise_tasks([task], task.idx)
+
+                        for task in tasks:
+                            if self.cfg.scheduler_config.splitwise_role != "mixed":
+                                # assure fetch block ids from D
+                                status, msg = self.split_connector.check_decode_allocated(task)
+                                if not status:
+                                    self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                                    self.scheduler.put_results(
+                                        [
+                                            RequestOutput(
+                                                request_id=task.request_id,
+                                                finished=True,
+                                                error_code=500,
+                                                error_msg=msg,
+                                            )
+                                        ]
+                                    )
+                                    need_delete_tasks.append(task)
+                                    continue
                     for tmp_task in need_delete_tasks:
                         tasks.remove(tmp_task)
                         # release resource in P
@@ -887,6 +903,7 @@ class EngineService:
                                 f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
                             )
                             continue
 
+                        self.token_processor.tokens_counter[task.request_id] = 1
                         self.resource_manager.insert_task_for_decoding(task)
                     else:
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 5ef9be1f2..01bf0f2d9 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -122,6 +122,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ENABLE_METRIC_LABELS": lambda: bool(int(os.getenv("FD_ENABLE_METRIC_LABELS", "0"))),
     # Default label values in metrics.
     "FD_DEFAULT_METRIC_LABEL_VALUES": lambda: os.getenv("FD_DEFAULT_METRIC_LABEL_VALUES", "{}"),
+    # Enable offline perf test mode for PD disaggregation
+    "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
 }
 
 
diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py
index ff49a5f27..ac28e467a 100644
--- a/fastdeploy/splitwise/splitwise_connector.py
+++ b/fastdeploy/splitwise/splitwise_connector.py
@@ -275,6 +275,7 @@ class SplitwiseConnector:
                 decode_diagg = task.disaggregate_info["cache_info"]
                 task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
                 task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
+                task.disaggregate_info["role"] = "decode"
                 self._send_message(addr, "prefill", [task])
                 task.disaggregate_info["cache_info"] = decode_diagg
                 task.disaggregate_info["role"] = "prefill"
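
For context on the cache_messager.py change: the queue payload swaps bare engine indexes for `(engine_idx, prefilled_token_num)` tuples, so the layerwise sender works from the token count captured when the signal was consumed instead of re-reading the shared `engine_cache_tasks` dict later. Below is a minimal producer/consumer sketch of that handshake; the names (`signal_queue`, `on_signals`, `send_cache`) are illustrative stand-ins, not the real CacheMessagerV1 API.

```python
import queue
import threading
import time

# Hypothetical stand-in for cache_prefilled_engine_ids_queue.
signal_queue: "queue.Queue[list[tuple[int, int]]]" = queue.Queue()


def on_signals(step_signals):
    """Producer side (mirrors consume_signals): snapshot the prefilled token
    count per engine index and enqueue the batch as (idx, count) tuples."""
    batch_engine_signals = [
        (engine_idx, chunk_token_offset + current_seq_len)
        for engine_idx, chunk_token_offset, current_seq_len in step_signals
    ]
    signal_queue.put(batch_engine_signals)


def send_cache():
    """Consumer side (mirrors the layerwise send thread): each entry carries the
    token count captured at signal time, so no re-read of shared state is needed."""
    while True:
        for engine_idx, prefilled_token_num in signal_queue.get():
            print(f"engine {engine_idx}: send blocks covering {prefilled_token_num} tokens")


threading.Thread(target=send_cache, daemon=True).start()
on_signals([(0, 0, 512), (1, 1024, 512)])  # two requests, one chunk each
time.sleep(0.1)  # let the daemon thread drain the queue before exiting
```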
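Similarly, a minimal sketch of the retry behavior behind the new `FD_OFFLINE_PERF_TEST_FOR_PD` switch in common_engine.py: when the flag is set, prefill keeps re-sending the splitwise task until decode confirms its blocks instead of returning a 500 result. `send_task` and `check_allocated` are hypothetical placeholders for `split_connector.send_splitwise_tasks` / `check_decode_allocated`, not the real signatures.

```python
import os
import time

# Same convention as envs.py: the flag is an int read from the environment.
FD_OFFLINE_PERF_TEST_FOR_PD = int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0"))


def ask_decode_resource(task, send_task, check_allocated, retry_interval=0.05):
    """Offline perf-test path: block until D allocates, never fail the request."""
    while True:
        send_task(task)
        ok, msg = check_allocated(task)
        if ok:
            return
        # The normal online path would turn this into an error RequestOutput;
        # here we just log and retry so offline benchmarks are not cut short.
        print(f"{task} ask D resource failed ({msg}), try again.")
        time.sleep(retry_interval)
```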