""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import annotations import os import signal import threading import time import traceback import weakref from collections import deque import numpy as np from fastdeploy.engine.request import RequestOutput from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.output.token_processor import TokenProcessor from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger class ExpertService: """ Engine class responsible for managing the Large Language Model (LLM) operations. Attributes: cfg (Config): Configuration object containing all the parameters. local_data_parallel_id (int): Local data parallel ID. """ def __init__(self, cfg, local_data_parallel_id): """ Initializes the LLMEngine with the provided configuration. Args: cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node end_pos = start_pos + self.cfg.tensor_parallel_size self.waiting_requests = [] self.disaggregate_queue = deque() self.llm_logger = get_logger("expert_service", f"expert_service_{local_data_parallel_id}.log") if cfg.splitwise_role != "mixed": self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos] self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos] self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id self.cfg.disaggregate_info = None self.scheduler = cfg.scheduler_config.scheduler() if self.cfg.scheduler_config.name == "splitwise": self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id address = (cfg.master_ip, cfg.engine_worker_queue_port + local_data_parallel_id) self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, client_id=0, num_client=cfg.tensor_parallel_size, local_data_parallel_id=local_data_parallel_id, ) self.resource_manager = ResourceManager( cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role, local_data_parallel_id, ) if cfg.splitwise_role != "mixed": if len(self.cfg.cache_config.pd_comm_port) == 1: self.cfg.cache_config.pd_comm_port[0] = ( int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id ) else: self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]] self.split_connector = SplitwiseConnector( self.cfg, self.scheduler, self.engine_worker_queue, self.resource_manager, self.disaggregate_queue ) self.token_processor = TokenProcessor( cfg=cfg, 
        self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
        for idx in range(1, self.cfg.max_num_partial_prefills + 1):
            self.partial_chunked_tokens[idx] = (
                (self.cfg.max_num_batched_tokens // idx)
                // self.cfg.cache_config.block_size
                * self.cfg.cache_config.block_size
            )

        self._finalizer = weakref.finalize(self, self._exit_sub_services)
        if envs.FD_ENABLE_INTERNAL_ADAPTER:
            self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id)

    def start(
        self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
    ):
        """
        Initializes the engine and starts its sub-services: the cache manager
        (for splitwise roles), the task-insertion thread, the token processor
        and the scheduler.
        """
        # assert not self.is_started, "The engine is already started."
        start_time = time.time()

        self.llm_logger.info(f"start expert service {local_data_parallel_id}")
        if self.cfg.splitwise_role != "mixed":
            self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
                cache_config=self.cfg.cache_config,
                tensor_parallel_size=self.cfg.tensor_parallel_size,
                device_ids=self.cfg.local_device_ids,
                pod_ip=self.cfg.master_ip,
                engine_worker_queue_port=self.cfg.engine_worker_queue_port,
                pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
            )
            self.split_mode_get_tasks()

        self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=())
        self.insert_task_to_worker_thread.daemon = True
        self.insert_task_to_worker_thread.start()

        # Start TokenProcessor thread
        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK

        self.token_processor.run()

        self.cfg.init_cache_info()

        role = self.cfg.splitwise_role
        host_ip = self.cfg.host_ip
        disaggregate = self.cfg.disaggregate_info
        if self.cfg.scheduler_config.name == "dp":
            assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
            self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
        elif self.cfg.scheduler_config.name == "splitwise":
            self.scheduler.start(role, host_ip, disaggregate)

        self.cfg.print()

        console_logger.info(f"Worker processes are launched in {time.time() - start_time} seconds.")
        return True

    def _insert_task_to_worker(self):
        """
        Monitor the scheduler request queue and, while the engine has free
        resources, pull requests and insert them into the engine.
        """
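        # Pull new requests from the scheduler only when there is spare batch capacity
        # and the worker queue has drained the previously submitted tasks; otherwise
        # back off briefly and retry.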
        current_id = -1
        while True:
            try:
                if self.resource_manager.available_batch() == 0:
                    time.sleep(0.001)
                    continue
                if self.engine_worker_queue.num_tasks() > 0:
                    time.sleep(0.001)
                    continue

                num_prefill_batch = min(
                    int(self.resource_manager.available_batch()),
                    self.cfg.max_prefill_batch,
                )

                self.resource_manager.check_and_free_block_tables()
                tasks = self.scheduler.get_requests(
                    available_blocks=self.resource_manager.available_block_num(),
                    block_size=self.cfg.cache_config.block_size,
                    reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                    max_num_batched_tokens=self.cfg.max_num_batched_tokens,
                    batch=num_prefill_batch,
                )

                if len(tasks) == 0:
                    time.sleep(0.001)
                    continue

                if self.cfg.splitwise_role != "mixed":
                    self.llm_logger.info("Inserting splitwise tasks")
                    self.split_connector.send_splitwise_tasks(tasks, current_id)

                current_id = (current_id + 1) % 100003

                self.insert_tasks(tasks, current_id)

                main_process_metrics.num_requests_waiting.dec(len(tasks))
                main_process_metrics.num_requests_running.inc(len(tasks))
            except Exception as e:
                err_msg = "Error happened while inserting task into engine: {}, {}.".format(
                    e, str(traceback.format_exc())
                )
                self.llm_logger.error(err_msg)

    def split_mode_get_tasks(self):
        """
        Receive tasks pushed by the splitwise connector and insert them into
        the engine from a background thread.
        """

        def receiver_loop():
            while True:
                try:
                    processed_indices = []
                    for idx, task in enumerate(self.waiting_requests):
                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                            self.insert_tasks([task])
                            self.llm_logger.info(f"Resource available, processing task {task.request_id}")
                            processed_indices.append(idx)
                        else:
                            self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
                            break

                    for idx in sorted(processed_indices, reverse=True):
                        self.waiting_requests.pop(idx)

                    if len(self.disaggregate_queue) > 0:
                        items = self.disaggregate_queue.pop()
                        role = items[0]
                        tasks = items[1]
                        if role == "prefill":
                            for task in tasks:
                                task.max_tokens = task.min_tokens = 2
                            self.insert_tasks(tasks)
                        elif role == "decode":
                            if hasattr(tasks[0], "finished"):
                                if not isinstance(tasks, list):
                                    tasks = [tasks]
                                for task in tasks:
                                    task.finished = False
                                self.insert_tasks(tasks, allocated=True)

                                if self.cfg.innode_prefill_ports is not None:
                                    self.scheduler.put_results(tasks)
                            else:
                                if len(self.waiting_requests):
                                    self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
                                    self.waiting_requests.extend(tasks)
                                else:
                                    new_waiting = []
                                    for task in tasks:
                                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                                            self.insert_tasks([task])
                                        else:
                                            if not self.enable_decode_cache_task:
                                                task.error_msg = "Not enough resources"
                                            new_waiting.append(task)

                                    if new_waiting:
                                        if not self.enable_decode_cache_task:
                                            self.split_connector.send_cache_infos(new_waiting, -1)
                                        else:
                                            self.waiting_requests.extend(new_waiting)
                                            self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
                    else:
                        time.sleep(0.001)
                except Exception as e:
                    self.llm_logger.error(f"Error in main loop: {e} {str(traceback.format_exc())}")
                    time.sleep(0.1)

        threading.Thread(target=receiver_loop, daemon=True).start()

    def insert_tasks(self, tasks, current_id=-1, allocated=False):
        """
        Insert tasks into the engine.
        """
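        # Two paths: with `allocated=True` the tasks already hold resources (results coming
        # back from a prefill instance), so the returned first token is copied onto the
        # locally cached task and it is forwarded directly; otherwise the tasks are freshly
        # scheduled and resources are allocated below.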
""" if allocated: current_tasks = [] for task in tasks: cur_task_idx = self.resource_manager.req_dict[task.request_id] del self.resource_manager.req_dict[task.request_id] cur_task = self.resource_manager.tasks_list[cur_task_idx] if task.error_code != 200: self.resource_manager.stop_flags[cur_task_idx] = True self.resource_manager.tasks_list[cur_task_idx] = None self.resource_manager._recycle_block_tables(cur_task) if task.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[task.request_id] self.scheduler.put_results([task]) self.llm_logger.warning( f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." ) continue self.llm_logger.info(f"{cur_task_idx} {task.request_id}") cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] self.token_processor.tokens_counter[task.request_id] = 1 current_tasks.append(cur_task) self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) return True self.resource_manager.check_and_free_block_tables() if not isinstance(tasks, list): tasks = [tasks] for task in tasks: if self.cfg.splitwise_role != "mixed": status, msg = self.split_connector.check_decode_allocated(task) if not status: self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, error_code=500, error_msg=msg, ) ] ) tasks.remove(task) continue task.schedule_start_time = time.time() available_batch = np.sum(self.resource_manager.stop_flags) if len(tasks) > available_batch: self.llm_logger.error( "Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch) ) self.llm_logger.error("The exceeded part will be ignored!") tasks = tasks[:available_batch] req_ids = [t.request_id for t in tasks] tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) if not tasks: error_msg = f"The request required resources is exceed the limit, request id={req_ids}." 
        available_batch = np.sum(self.resource_manager.stop_flags)
        if len(tasks) > available_batch:
            self.llm_logger.error(
                "Inserting batch: {} exceeds the available batch: {}.".format(len(tasks), available_batch)
            )
            self.llm_logger.error("The excess tasks will be ignored!")
            tasks = tasks[:available_batch]

        req_ids = [t.request_id for t in tasks]

        tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)

        if not tasks:
            error_msg = f"The requested resources exceed the limit, request id={req_ids}."
            self.llm_logger.error(error_msg)
            raise EngineError(error_msg, error_code=500)

        self.token_processor.number_of_tasks += len(tasks)

        is_decode = False
        is_prefill = False
        for i in range(len(tasks)):
            if tasks[i].disaggregate_info is not None:
                if tasks[i].disaggregate_info["role"] == "decode":
                    is_decode = True
                else:
                    is_prefill = True
            self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len

        if is_decode or is_prefill:
            self.split_connector.send_cache_infos(tasks, current_id)
        for task in tasks:
            task.infer_start_time = time.time()
        if not is_decode:
            self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
            if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
                if not self.cfg.enable_mm:
                    self.update_requests_chunk_size(tasks)
                else:
                    self.update_mm_requests_chunk_size(tasks)
        self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
        return True

    def _exit_sub_services(self):
        """
        Exit the sub-services launched by this expert service.
        """
        if hasattr(self, "cache_manager_processes"):
            self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
            self.resource_manager.cache_manager.cache_ready_signal.clear()
            for p in self.cache_manager_processes:
                self.llm_logger.info(f"Killing cache manager process {p.pid}")
                try:
                    os.killpg(p.pid, signal.SIGTERM)
                except Exception:
                    pass

        if hasattr(self, "zmq_server") and self.zmq_server is not None:
            self.zmq_server.close()


def start_expert_service(
    cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
):
    """
    Start an expert service for the given local data-parallel rank.
    """
    expert_service = ExpertService(cfg, local_data_parallel_id)
    try:
        expert_service.start(
            ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
        )
        expert_service.split_connector.start_receiver()
    except Exception as e:
        llm_logger.exception(f"Expert service failed to start: {e}")
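

# Example (illustrative only): one way a driver process might launch one expert service
# per local data-parallel rank. `cfg`, `ipc_signal_suffix` and `num_local_dp_ranks` are
# placeholders for values produced by the engine's startup code; they are not defined here.
#
#     from multiprocessing import Process
#
#     procs = []
#     for dp_rank in range(num_local_dp_ranks):
#         p = Process(target=start_expert_service, args=(cfg, dp_rank, ipc_signal_suffix))
#         p.start()
#         procs.append(p)
#     for p in procs:
#         p.join()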