""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import annotations import copy import os import signal import threading import time import traceback import weakref from collections import deque import numpy as np from fastdeploy.engine.request import RequestOutput from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.output.token_processor import TokenProcessor from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger class ExpertService: """ Engine class responsible for managing the Large Language Model (LLM) operations. Attributes: cfg (Config): Configuration object containing all the parameters. local_data_parallel_id (int): Local data parallel ID. """ def __init__(self, cfg, local_data_parallel_id): """ Initializes the LLMEngine with the provided configuration. Args: cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node end_pos = start_pos + self.cfg.tensor_parallel_size self.waiting_requests = [] self.disaggregate_queue = deque() self.llm_logger = get_logger("expert_service", f"expert_service_{local_data_parallel_id}.log") if cfg.splitwise_role != "mixed": self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos] self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos] self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id self.cfg.disaggregate_info = None self.scheduler = cfg.scheduler_config.scheduler() if self.cfg.scheduler_config.name == "splitwise": self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id address = (cfg.master_ip, cfg.engine_worker_queue_port + local_data_parallel_id) self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, client_id=0, num_client=cfg.tensor_parallel_size, local_data_parallel_id=local_data_parallel_id, ) self.resource_manager = ResourceManager( cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role, local_data_parallel_id, ) if cfg.splitwise_role != "mixed": if len(self.cfg.cache_config.pd_comm_port) == 1: self.cfg.cache_config.pd_comm_port[0] = ( int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id ) else: self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]] self.split_connector = SplitwiseConnector( self.cfg, self.scheduler, self.engine_worker_queue, self.resource_manager, self.disaggregate_queue ) self.token_processor = TokenProcessor( cfg=cfg, cached_generated_tokens=self.scheduler, engine_worker_queue=self.engine_worker_queue, split_connector=self.split_connector, ) self.token_processor.set_resource_manager(self.resource_manager) self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1) for idx in range(1, self.cfg.max_num_partial_prefills + 1): self.partial_chunked_tokens[idx] = ( (self.cfg.max_num_batched_tokens // idx) // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size ) self._finalizer = weakref.finalize(self, self._exit_sub_services) if envs.FD_ENABLE_INTERNAL_ADAPTER: self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id) def start( self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None ): """ Initializes the engine and starts its sub-services. If `api_server_pid` is defined, will launch a thread to keep getting request from zmq_server. """ # assert not self.is_started, "The engine is already started." start_time = time.time() self.llm_logger.info(f"start expert service {local_data_parallel_id}") if self.cfg.splitwise_role != "mixed": self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager( cache_config=self.cfg.cache_config, tensor_parallel_size=self.cfg.tensor_parallel_size, device_ids=self.cfg.local_device_ids, pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}", ) self.split_mode_get_tasks() self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=()) self.insert_task_to_worker_thread.daemon = True self.insert_task_to_worker_thread.start() # Start TokenProcessor thread os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port)) self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK self.token_processor.run() self.cfg.init_cache_info() role = self.cfg.splitwise_role host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info if self.cfg.scheduler_config.name == "dp": assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None) self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc) elif self.cfg.scheduler_config.name == "splitwise": self.scheduler.start(role, host_ip, disaggregate) self.cfg.print() console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True def _insert_task_to_worker(self): """ Insert task to engine thread, monitor scheduler request queue. if the engine has resource, insert task to engine """ current_id = 0 while True: try: if self.resource_manager.available_batch() == 0: time.sleep(0.001) continue if self.engine_worker_queue.num_tasks() > 0: time.sleep(0.001) continue num_prefill_batch = min( int(self.resource_manager.available_batch()), self.cfg.max_prefill_batch, ) if envs.FD_ENABLE_INTERNAL_ADAPTER: num_prefill_batch = int(self.resource_manager.available_batch()) self.resource_manager.check_and_free_block_tables() tasks = self.scheduler.get_requests( available_blocks=self.resource_manager.available_block_num(), block_size=self.cfg.cache_config.block_size, reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, max_num_batched_tokens=self.cfg.max_num_batched_tokens, batch=num_prefill_batch, ) if len(tasks) == 0: time.sleep(0.001) continue if self.cfg.splitwise_role != "mixed": self.llm_logger.info("Inserting splitwise tasks") self.split_connector.send_splitwise_tasks(tasks, current_id) insert_successful = self.insert_tasks(tasks, current_id) if insert_successful: current_id = current_id + 1 else: continue main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks)) except Exception as e: err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) self.llm_logger.error(err_msg) def split_mode_get_tasks(self): """ Split mode get tasks """ def receiver_loop(): while True: try: processed_indices = [] for idx, task in enumerate(self.waiting_requests): if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): self.insert_tasks([task]) self.llm_logger.info(f"Resource available, processing task {task.request_id}") processed_indices.append(idx) else: self.llm_logger.debug(f"Still waiting for resources {task.request_id}") break for idx in sorted(processed_indices, reverse=True): self.waiting_requests.pop(idx) if len(self.disaggregate_queue) > 0: items = self.disaggregate_queue.pop() role = items[0] tasks = items[1] if role == "prefill": for task in tasks: task.max_tokens = task.min_tokens = 2 self.insert_tasks(tasks) elif role == "decode": if hasattr(tasks[0], "finished"): if not isinstance(tasks, list): tasks = [tasks] for task in tasks: task.finished = False self.insert_tasks(tasks, allocated=True) if self.cfg.innode_prefill_ports is not None: self.scheduler.put_results(tasks) else: if len(self.waiting_requests): self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}") self.waiting_requests.extend(tasks) else: new_waiting = [] for task in tasks: if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): self.insert_tasks([task]) else: if not self.enable_decode_cache_task: task.error_msg = "Not enough resources" new_waiting.append(task) if new_waiting: if not self.enable_decode_cache_task: self.split_connector.send_cache_infos(new_waiting, -1) else: self.waiting_requests.extend(new_waiting) self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue") else: time.sleep(0.001) except Exception as e: self.llm_logger.error(f"Error in main loop: {e} {str(traceback.format_exc())}") time.sleep(0.1) threading.Thread(target=receiver_loop, daemon=True).start() def insert_tasks(self, tasks, current_id=-1, allocated=False): """ Insert tasks to engine. """ if allocated: current_tasks = [] for task in tasks: cur_task_idx = self.resource_manager.req_dict[task.request_id] del self.resource_manager.req_dict[task.request_id] cur_task = self.resource_manager.tasks_list[cur_task_idx] if envs.FD_ENABLE_INTERNAL_ADAPTER: if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue self.resource_manager.stop_flags[cur_task_idx] = True self.resource_manager.tasks_list[cur_task_idx] = None self.resource_manager._recycle_block_tables(cur_task) if task.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[task.request_id] self.llm_logger.warning(f"{task.request_id} need not decode after first token") continue cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode": cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids) if task.error_code != 200: self.resource_manager.stop_flags[cur_task_idx] = True self.resource_manager.tasks_list[cur_task_idx] = None self.resource_manager._recycle_block_tables(cur_task) if task.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[task.request_id] self.llm_logger.warning( f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." ) continue self.token_processor.tokens_counter[task.request_id] = 1 current_tasks.append(cur_task) if current_tasks: self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) return True self.resource_manager.check_and_free_block_tables() if not isinstance(tasks, list): tasks = [tasks] need_delete_tasks = [] for task in tasks: if self.cfg.splitwise_role != "mixed": status, msg = self.split_connector.check_decode_allocated(task) if not status: self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, error_code=500, error_msg=msg, ) ] ) need_delete_tasks.append(task) continue for tmp_task in need_delete_tasks: tasks.remove(tmp_task) for item in tasks: item.schedule_start_time = time.time() req_ids = [t.request_id for t in tasks] if len(tasks) == 0: return False available_batch = np.sum(self.resource_manager.stop_flags) if len(tasks) > available_batch: self.llm_logger.error( "Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch) ) self.llm_logger.error("The exceeded part will be ignored!") tasks = tasks[:available_batch] tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) if not tasks: error_msg = f"The request required resources is exceed the limit, request id={req_ids}." self.llm_logger.error(error_msg) raise EngineError(error_msg, error_code=500) return False self.token_processor.number_of_tasks += len(tasks) is_decode = False is_prefill = False for i in range(len(tasks)): if tasks[i].disaggregate_info is not None: if tasks[i].disaggregate_info["role"] == "decode": is_decode = True else: is_prefill = True self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len if is_decode or is_prefill: self.split_connector.send_cache_infos(tasks, current_id) for task in tasks: task.infer_start_time = time.time() if not is_decode: self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") if not is_prefill and self.cfg.cache_config.enable_chunked_prefill: if not self.cfg.enable_mm: self.update_requests_chunk_size(tasks) else: self.update_mm_requests_chunk_size(tasks) self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) return True def _exit_sub_services(self): """ exit sub services """ if hasattr(self, "cache_manager_processes"): self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() self.resource_manager.cache_manager.cache_ready_signal.clear() for p in self.cache_manager_processes: self.llm_logger.info(f"Killing cache manager process {p.pid}") try: os.killpg(p.pid, signal.SIGTERM) except: pass if hasattr(self, "zmq_server") and self.zmq_server is not None: self.zmq_server.close() def start_expert_service( cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None ): """ Start expert service """ expert_service = ExpertService(cfg, local_data_parallel_id) try: expert_service.start( ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc ) expert_service.split_connector.start_receiver() except Exception as e: llm_logger.exception(f"Expert service failed to start: {e}")