From aba94169dcc1367ddb70fdcd244f2ee09231ea9a Mon Sep 17 00:00:00 2001 From: chenjian <1435317881@qq.com> Date: Mon, 18 Aug 2025 11:43:36 +0800 Subject: [PATCH] [Feature] Support batched tokens for EP (#3415) * Support batched tokens for EP * Support batched tokens for EP * Support batched tokens for EP * Support batched tokens for EP * Support batched tokens for EP and fix bug * Support batched tokens for EP and fix bug * Support batched tokens for EP and fix bug * Support batched tokens for EP and fix bug --- fastdeploy/engine/engine.py | 16 +- fastdeploy/engine/expert_service.py | 18 +- fastdeploy/envs.py | 2 + .../model_executor/pre_and_post_process.py | 204 ++++++++++-------- fastdeploy/output/token_processor.py | 2 +- fastdeploy/scheduler/dp_scheduler.py | 83 ++++++- .../splitwise/internal_adapter_utils.py | 1 + fastdeploy/splitwise/splitwise_connector.py | 2 + fastdeploy/worker/gpu_model_runner.py | 4 +- 9 files changed, 235 insertions(+), 97 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index fc8fd8d08..8fb9858b6 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -360,6 +360,9 @@ class LLMEngine: self.cfg.max_prefill_batch, ) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + num_prefill_batch = int(self.resource_manager.available_batch()) + self.resource_manager.check_and_free_block_tables() tasks = self.scheduler.get_requests( available_blocks=self.resource_manager.available_block_num(), @@ -790,6 +793,15 @@ class LLMEngine: cur_task_idx = self.resource_manager.req_dict[task.request_id] del self.resource_manager.req_dict[task.request_id] cur_task = self.resource_manager.tasks_list[cur_task_idx] + if envs.FD_ENABLE_INTERNAL_ADAPTER: + if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue + self.resource_manager.stop_flags[cur_task_idx] = True + self.resource_manager.tasks_list[cur_task_idx] = None + self.resource_manager._recycle_block_tables(cur_task) + if task.request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[task.request_id] + llm_logger.warning(f"{task.request_id} need not decode after first token") + continue cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode": cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids) @@ -799,14 +811,14 @@ class LLMEngine: self.resource_manager._recycle_block_tables(cur_task) if task.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[task.request_id] - self.scheduler.put_results([task]) llm_logger.warning( f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." 
) continue self.token_processor.tokens_counter[task.request_id] = 1 current_tasks.append(cur_task) - self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) + if current_tasks: + self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) return True for task in tasks: diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index dba21cc20..048b9e7d3 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -186,7 +186,8 @@ class ExpertService: int(self.resource_manager.available_batch()), self.cfg.max_prefill_batch, ) - + if envs.FD_ENABLE_INTERNAL_ADAPTER: + num_prefill_batch = int(self.resource_manager.available_batch()) self.resource_manager.check_and_free_block_tables() tasks = self.scheduler.get_requests( available_blocks=self.resource_manager.available_block_num(), @@ -294,6 +295,15 @@ class ExpertService: cur_task_idx = self.resource_manager.req_dict[task.request_id] del self.resource_manager.req_dict[task.request_id] cur_task = self.resource_manager.tasks_list[cur_task_idx] + if envs.FD_ENABLE_INTERNAL_ADAPTER: + if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue + self.resource_manager.stop_flags[cur_task_idx] = True + self.resource_manager.tasks_list[cur_task_idx] = None + self.resource_manager._recycle_block_tables(cur_task) + if task.request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[task.request_id] + self.llm_logger.warning(f"{task.request_id} need not decode after first token") + continue cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode": cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids) @@ -303,16 +313,14 @@ class ExpertService: self.resource_manager._recycle_block_tables(cur_task) if task.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[task.request_id] - self.scheduler.put_results([task]) self.llm_logger.warning( f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." ) continue - self.llm_logger.info(f"{cur_task_idx} {task.request_id}") - cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] self.token_processor.tokens_counter[task.request_id] = 1 current_tasks.append(cur_task) - self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) + if current_tasks: + self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) return True self.resource_manager.check_and_free_block_tables() diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 25c4b0f83..67799a9fa 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -88,6 +88,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), # LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1 "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), + # Batched token timeout in EP + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), # Whether to use PLUGINS. 
"FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","), # Whether to enable cache task in decode node diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 5a14d77b4..d183a6b0b 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -65,6 +65,7 @@ else: update_inputs, step_reschedule, update_inputs_v1, + speculate_step_reschedule, ) from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput @@ -355,12 +356,11 @@ def step_cuda( """ if speculative_config.method is not None: - if enable_prefix_caching: - speculate_step_system_cache( + if DISABLE_RECOVER: + speculate_step_reschedule( share_inputs["stop_flags"], share_inputs["seq_lens_this_time"], share_inputs["step_seq_lens_encoder"], - share_inputs["step_seq_lens_decoder"], share_inputs["seq_lens_encoder"], share_inputs["seq_lens_decoder"], share_inputs["block_tables"], @@ -386,64 +386,67 @@ def step_cuda( speculative_config.num_speculative_tokens, ) else: - speculate_step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - share_inputs["accept_num"], - block_size, - enc_dec_block_num, - speculative_config.num_speculative_tokens, - ) + if enable_prefix_caching: + speculate_step_system_cache( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["step_seq_lens_decoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + share_inputs["accept_num"], + block_size, + enc_dec_block_num, + speculative_config.num_speculative_tokens, + ) + else: + speculate_step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + 
share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + share_inputs["accept_num"], + block_size, + enc_dec_block_num, + speculative_config.num_speculative_tokens, + ) else: - if enable_prefix_caching: - step_system_cache( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["step_seq_lens_decoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) - elif DISABLE_RECOVER: + if DISABLE_RECOVER: step_reschedule( share_inputs["stop_flags"], share_inputs["seq_lens_this_time"], @@ -471,32 +474,61 @@ def step_cuda( enc_dec_block_num, ) else: - step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) + if enable_prefix_caching: + step_system_cache( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["step_seq_lens_decoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) + else: + step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + 
share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) def rebuild_padding( diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index fbb978407..3f590b73c 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -270,7 +270,7 @@ class TokenProcessor: self.resource_manager._recycle_block_tables(task) if self.prefill_result_status[task_id] != "finished": result.error_code = 400 - result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}" + result.error_msg = f"{task_id} failed to {self.prefill_result_status[task_id]}" self.split_connector.send_first_token(task.disaggregate_info, [result]) del self.resource_manager.req_dict[task_id] break diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py index d55a68790..d5d1d3967 100644 --- a/fastdeploy/scheduler/dp_scheduler.py +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import logging import threading import time from multiprocessing import Queue @@ -22,7 +23,7 @@ from typing import Dict, List, Optional from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledResponse from fastdeploy.scheduler.local_scheduler import LocalScheduler -from fastdeploy.utils import scheduler_logger +from fastdeploy.utils import envs, get_logger class DPLocalScheduler(LocalScheduler): @@ -45,6 +46,7 @@ class DPLocalScheduler(LocalScheduler): long_prefill_token_threshold, ) self.splitwise_role = splitwise_role + self.scheduler_logger = logging def put_results(self, results: List[RequestOutput]): """ @@ -56,7 +58,7 @@ class DPLocalScheduler(LocalScheduler): finished_responses = [response.request_id for response in responses if response.finished] if len(finished_responses) > 0: - scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}") + self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}") with self.mutex: for response in responses: @@ -107,6 +109,80 @@ class DPLocalScheduler(LocalScheduler): else: self.ids_read_cursor -= len(expired_ids) + def get_requests( + self, + available_blocks, + block_size, + reserved_output_blocks, + max_num_batched_tokens, + batch=1, + ) -> List[Request]: + """ + Retrieve requests from the scheduler based on available resources. 
+ + Args: + available_blocks: Number of available processing blocks + block_size: Size of each processing block + reserved_output_blocks: Blocks reserved for output + max_num_batched_tokens: Maximum tokens that can be batched + batch: Preferred batch size + + Returns: + List of Request objects ready for processing + """ + if available_blocks <= reserved_output_blocks or batch < 1: + self.scheduler_logger.debug( + f"Scheduler's resource are insufficient: available_blocks={available_blocks} " + f"reserved_output_blocks={reserved_output_blocks} batch={batch} " + f"max_num_batched_tokens={max_num_batched_tokens}" + ) + return [] + required_total_blocks = 0 + current_prefill_tokens = 0 + start_batch_time = time.time() + requests: List[Request] = [] + + with self.requests_not_empty: + while True: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], + 0.005, + ) + if batch_ids: + for request_id in batch_ids: + request = self.requests[request_id] + required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks + if required_total_blocks > available_blocks: + break + if current_prefill_tokens > max_num_batched_tokens: + break + + requests.append(request.raw) + self.ids_read_cursor += 1 + start_batch_time = time.time() + if len(requests) >= batch: + break + if ( + (current_prefill_tokens > max_num_batched_tokens) + or (len(requests) >= batch) + or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT) + ): + break + if batch_ids: + if len(batch_ids) > 0 and len(requests) == 0: + self.scheduler_logger.debug( + f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}" + ) + + if len(requests) > 0: + self.scheduler_logger.info( + f"Scheduler has pulled some request: {[request.request_id for request in requests]}" + ) + + return requests + class DPScheduler: def __init__( @@ -133,6 +209,8 @@ class DPScheduler: self.dp_rank = dp_rank self.request_queues = request_queues self.result_queue = result_queue + self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log") + self._scheduler.scheduler_logger = self.scheduler_logger threading.Thread(target=self._put_requests_to_local).start() threading.Thread(target=self._get_response_from_local).start() @@ -148,6 +226,7 @@ class DPScheduler: def _put_requests_to_local(self): while True: request = self.request_queues[self.dp_rank].get() + self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}") self._scheduler.put_requests([request]) def _get_response_from_local(self): diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index dfee8b41c..6288a30f9 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -56,6 +56,7 @@ class InternalAdapter: "splitwise_role": self.cfg.splitwise_role, "block_size": int(self.cfg.cache_config.block_size), "block_num": int(available_block_num), + "max_block_num": self.cfg.cache_config.total_block_num, "dec_token_num": int(self.cfg.cache_config.dec_token_num), "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num, "max_batch_size": int(available_batch_size), diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 
f3611a0fb..dbcb46b47 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -506,6 +506,8 @@ class SplitwiseConnector: draft_token_ids=task["outputs"]["draft_token_ids"], ), finished=True, + error_code=task["error_code"], + error_msg=task["error_msg"], ) ) req_ids = [task["request_id"] for task in payload] diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index fdb73609d..c5cf9ebde 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1274,7 +1274,7 @@ class GPUModelRunner(ModelRunnerBase): if not self.not_need_stop(): self._execute_empty_input() return None - + start_time = time.time() # 1. Prepare inputs of model and sampler. skip_idx_list = self._get_skip_idx(model_forward_batch) self._prepare_inputs() @@ -1409,6 +1409,8 @@ class GPUModelRunner(ModelRunnerBase): self._update_chunked_prefill(model_forward_batch) self._add_cache(model_forward_batch) + end_time = time.time() + logger.debug(f"execute one step cost time: {end_time-start_time:.3f} s") return None def _add_cache(self, model_forward_batch) -> None:
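
The engine.py and expert_service.py hunks above add an FD_ENABLE_INTERNAL_ADAPTER branch on the decode side: when a prefill result arrives with an empty token_ids list (the first generated token was EOS) or with an error, the slot is recycled immediately and the request is never forwarded to the decode workers. The sketch below restates that control flow in isolation and folds the EOS and error branches together; the data classes and the slot bookkeeping are simplified stand-ins for FastDeploy's ResourceManager and TokenProcessor, not the real interfaces.

```python
# Simplified sketch of decode-side handling of prefill results when
# FD_ENABLE_INTERNAL_ADAPTER is enabled. The types below are minimal
# stand-ins used only to make the control flow runnable in isolation.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class PrefillResult:
    request_id: str
    token_ids: List[int]          # empty => first token was EOS
    error_code: int = 200
    error_msg: str = ""


def accept_prefill_result(
    result: PrefillResult,
    slot_by_request: Dict[str, int],       # request_id -> slot index
    tasks_list: List[Optional[object]],    # per-slot task objects
    stop_flags: List[bool],                # per-slot stop markers
    tokens_counter: Dict[str, int],        # request_id -> emitted token count
    recycle_blocks,                        # callable(task): free its KV blocks
) -> Optional[object]:
    """Return the task to schedule for decode, or None if it was recycled."""
    slot = slot_by_request.pop(result.request_id)
    task = tasks_list[slot]

    # First token was EOS, or prefill failed: free the slot right away and
    # do not hand the request to the decode workers.
    if not result.token_ids or result.error_code != 200:
        stop_flags[slot] = True
        tasks_list[slot] = None
        recycle_blocks(task)
        tokens_counter.pop(result.request_id, None)
        return None

    # Normal case: seed decode with the first token produced by prefill.
    task.prompt_token_ids[0] = result.token_ids[0]
    tokens_counter[result.request_id] = 1
    return task


if __name__ == "__main__":
    from types import SimpleNamespace

    slots = {"req-0": 0}
    tasks = [SimpleNamespace(prompt_token_ids=[0])]
    stops = [False]
    counter: Dict[str, int] = {}

    # Empty token_ids: the slot is freed and nothing is scheduled for decode.
    out = accept_prefill_result(
        PrefillResult("req-0", token_ids=[]), slots, tasks, stops, counter,
        recycle_blocks=lambda t: None,
    )
    assert out is None and stops[0] and tasks[0] is None
```

Note that the patch also only calls put_tasks when current_tasks is non-empty, so a batch in which every prefill ended at EOS sends nothing to the decode queue.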
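
envs.py keeps a table of environment variables mapped to zero-argument parsers, and the patch adds FD_EP_BATCHED_TOKEN_TIMEOUT (seconds, default 0.1) so the EP scheduler knows how long it may wait while batching prefill requests. A minimal reproduction of that lazy-lookup pattern; the get_env helper is a hypothetical accessor, not FastDeploy's actual API.

```python
import os
from typing import Any, Callable, Dict

# Same pattern as fastdeploy/envs.py: each variable maps to a zero-argument
# parser, so the value is read (and type-converted) only when first needed.
environment_variables: Dict[str, Callable[[], Any]] = {
    # How long (in seconds) the EP scheduler may keep waiting while it
    # accumulates a batch of prefill requests before dispatching what it has.
    "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
}


def get_env(name: str) -> Any:
    """Hypothetical accessor: resolve a variable through its parser."""
    return environment_variables[name]()


if __name__ == "__main__":
    print(get_env("FD_EP_BATCHED_TOKEN_TIMEOUT"))  # 0.1 unless overridden
```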
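
The pre_and_post_process.py hunk reorders step_cuda so that, in both the speculative and the non-speculative paths, the DISABLE_RECOVER flag is checked first and routes to the *_reschedule ops (newly importing speculate_step_reschedule), and only otherwise does prefix caching select the *_system_cache variants, with the *_paddle ops as the fallback. The sketch below captures just that dispatch order; the step functions are stubs, since the real custom ops take the long share_inputs argument lists shown above.

```python
# Dispatch-only sketch of the reordered step_cuda control flow. The six
# step_* callables stand in for FastDeploy's custom ops; here they are
# stubs so the routing itself can be exercised.
def _stub(name):
    return lambda **kwargs: name


speculate_step_reschedule = _stub("speculate_step_reschedule")
speculate_step_system_cache = _stub("speculate_step_system_cache")
speculate_step_paddle = _stub("speculate_step_paddle")
step_reschedule = _stub("step_reschedule")
step_system_cache = _stub("step_system_cache")
step_paddle = _stub("step_paddle")


def choose_step_op(speculative: bool, disable_recover: bool, enable_prefix_caching: bool):
    """Mirror the branch order introduced by the patch:
    DISABLE_RECOVER wins, then prefix caching, then the default kernel."""
    if speculative:
        if disable_recover:
            return speculate_step_reschedule
        if enable_prefix_caching:
            return speculate_step_system_cache
        return speculate_step_paddle
    if disable_recover:
        return step_reschedule
    if enable_prefix_caching:
        return step_system_cache
    return step_paddle


if __name__ == "__main__":
    # Speculative decoding with recovery disabled now reschedules instead of
    # falling through to the prefix-caching path.
    assert choose_step_op(True, True, True)() == "speculate_step_reschedule"
    assert choose_step_op(False, False, True)() == "step_system_cache"
    assert choose_step_op(False, False, False)() == "step_paddle"
```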
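
The core of the feature is DPLocalScheduler.get_requests in dp_scheduler.py: it keeps polling its request list and accumulates prefill requests until the block budget, max_num_batched_tokens, the preferred batch size, or FD_EP_BATCHED_TOKEN_TIMEOUT stops it. The standalone sketch below reproduces that accumulation loop over a plain deque; the Req shape, calc_required_blocks, and the peek-then-pop bookkeeping are simplified assumptions standing in for the scheduler's cursor-based internals.

```python
import time
from collections import deque
from dataclasses import dataclass
from typing import Deque, List


@dataclass
class Req:
    request_id: str
    prompt_len: int


def calc_required_blocks(token_num: int, block_size: int) -> int:
    # Assumed helper: ceil-divide a prompt into KV-cache blocks.
    return (token_num + block_size - 1) // block_size


def get_batched_requests(
    pending: Deque[Req],
    available_blocks: int,
    block_size: int,
    reserved_output_blocks: int,
    max_num_batched_tokens: int,
    batch: int,
    batched_token_timeout: float = 0.1,   # FD_EP_BATCHED_TOKEN_TIMEOUT
) -> List[Req]:
    """Accumulate prefill requests until a resource limit, the preferred
    batch size, or the batching timeout stops the loop."""
    if available_blocks <= reserved_output_blocks or batch < 1:
        return []

    picked: List[Req] = []
    used_blocks = 0
    prefill_tokens = 0
    start = time.time()

    while True:
        progressed = False
        if pending:
            head = pending[0]              # peek, like reading self.ids by cursor
            need = calc_required_blocks(head.prompt_len, block_size) + reserved_output_blocks
            if used_blocks + need > available_blocks:
                break                      # head does not fit the block budget
            if prefill_tokens + head.prompt_len > max_num_batched_tokens:
                break                      # head would exceed the token budget
            pending.popleft()
            used_blocks += need
            prefill_tokens += head.prompt_len
            picked.append(head)
            start = time.time()            # progress resets the batching timer
            progressed = True
        if len(picked) >= batch:
            break
        if time.time() - start > batched_token_timeout:
            break
        if not progressed:
            time.sleep(0.005)              # mirror the 0.005 s wait_for poll
    return picked


if __name__ == "__main__":
    q: Deque[Req] = deque(Req(f"r{i}", prompt_len=64) for i in range(4))
    out = get_batched_requests(
        q, available_blocks=100, block_size=64, reserved_output_blocks=2,
        max_num_batched_tokens=256, batch=8, batched_token_timeout=0.05,
    )
    print([r.request_id for r in out])     # all four, then the 0.05 s timeout fires
```

Because the timer restarts whenever a request is admitted, the timeout bounds the extra latency spent waiting for more work after the last admission, not the total time spent building the batch.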
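
dp_scheduler.py also gives each data-parallel rank its own log file: DPScheduler builds a rank-specific logger and hands it to the wrapped DPLocalScheduler, which otherwise falls back to the bare logging module. A sketch of that wiring using only the standard logging package, since get_logger here is FastDeploy's own helper; the stub class name is illustrative.

```python
import logging


def make_rank_logger(dp_rank: int) -> logging.Logger:
    """Per-rank file logger, mirroring get_logger("dpscheduler", f"dp_scheduler_rank{dp_rank}.log")."""
    logger = logging.getLogger(f"dpscheduler.rank{dp_rank}")
    if not logger.handlers:                      # avoid duplicate handlers on re-entry
        handler = logging.FileHandler(f"dp_scheduler_rank{dp_rank}.log")
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


class LocalSchedulerStub:
    # Default from DPLocalScheduler.__init__: fall back to the logging module.
    scheduler_logger = logging


if __name__ == "__main__":
    sched = LocalSchedulerStub()
    sched.scheduler_logger = make_rank_logger(dp_rank=0)   # what DPScheduler.__init__ does
    sched.scheduler_logger.info("Receive request from puller, request_id: demo")
```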
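
Finally, the token_processor.py and splitwise_connector.py hunks make prefill failures visible to the decode side: the prefill side now writes error_msg (the field RequestOutput actually exposes, rather than the misspelled error_message), and the connector copies error_code and error_msg when rebuilding the first-token result from the wire payload. A compact illustration of that round trip, with a simplified result type standing in for fastdeploy.engine.request.RequestOutput.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FirstTokenResult:
    """Simplified stand-in for RequestOutput as it travels prefill -> decode."""
    request_id: str
    token_ids: List[int] = field(default_factory=list)
    finished: bool = True
    error_code: int = 200
    error_msg: str = ""


def mark_prefill_failure(result: FirstTokenResult, status: str) -> None:
    # Prefill side: record the failure on the fields the decode side reads.
    if status != "finished":
        result.error_code = 400
        result.error_msg = f"{result.request_id} failed to {status}"


def decode_from_wire(task: dict) -> FirstTokenResult:
    # Decode side: keep the error fields when rebuilding the result, so the
    # engine recycles the slot instead of scheduling a doomed request.
    return FirstTokenResult(
        request_id=task["request_id"],
        token_ids=list(task["outputs"]["token_ids"]),
        finished=True,
        error_code=task["error_code"],
        error_msg=task["error_msg"],
    )


if __name__ == "__main__":
    r = FirstTokenResult("req-0")
    mark_prefill_failure(r, "error")
    wire = {
        "request_id": r.request_id,
        "outputs": {"token_ids": r.token_ids},
        "error_code": r.error_code,
        "error_msg": r.error_msg,
    }
    print(decode_from_wire(wire))            # error_code=400 survives the transfer
```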