""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import annotations import copy import os import threading import time import traceback import weakref from concurrent.futures import ThreadPoolExecutor from typing import Dict, List, Optional, Tuple import numpy as np import paddle import zmq from opentelemetry import trace from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 from fastdeploy.inter_communicator import ( EngineCacheQueue, EngineWorkerQueue, IPCSignal, ZmqClient, ) from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.trace_util import start_span, start_span_request from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.output.token_processor import TokenProcessor from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, envs, llm_logger class EngineSevice: """ Base class containing common engine functionality """ def __init__(self, cfg): """ Initializes the LLMEngine with the provided configuration. Args: cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg self.scheduler = cfg.scheduler_config.scheduler() if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager = ResourceManagerV1( cfg.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role, cfg.parallel_config.local_data_parallel_id, ) if cfg.splitwise_role != "mixed": raise NotImplementedError( "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now." ) else: self.resource_manager = ResourceManager( cfg.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role, cfg.parallel_config.local_data_parallel_id, ) self.start_worker_queue_service() os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[ self.cfg.parallel_config.local_data_parallel_id ] self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager) self.waiting_requests = [] self.token_processor = TokenProcessor( cfg=cfg, cached_generated_tokens=self.scheduler, engine_worker_queue=self.engine_worker_queue, split_connector=self.split_connector, ) self.token_processor.set_resource_manager(self.resource_manager) self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1) for idx in range(1, self.cfg.max_num_partial_prefills + 1): self.partial_chunked_tokens[idx] = ( (self.cfg.max_num_batched_tokens // idx) // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size ) self.guided_decoding_checker = None if self.cfg.guided_decoding_backend != "off": self.guided_decoding_checker = schema_checker( self.cfg.guided_decoding_backend, disable_any_whitespace=self.cfg.disable_any_whitespace, ) self._init_worker_monitor_signals() self._finalizer = weakref.finalize(self, self._exit_sub_services) def start(self): self.running = True if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True) else: self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True) self.insert_task_to_worker_thread.start() self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.run() def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理 current_suffix = int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]) llm_logger.info(f"current_suffix: {current_suffix}") exist_task_signal_data = np.zeros([1], dtype=np.int32) self.exist_task_signal = IPCSignal( name="exist_task_signal", array=exist_task_signal_data, dtype=np.int32, suffix=current_suffix, create=True, ) # exist_swapped_task_signal 用于engine感知worker中是否存在swapped task exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32) self.exist_swapped_task_signal = IPCSignal( name="exist_swapped_task_signal", array=exist_swapped_task_signal_data, dtype=np.int32, suffix=current_suffix, create=True, ) # exist_prefill_task_signal 用于各worker进程感知是否进行prefill exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32) self.exist_prefill_task_signal = IPCSignal( name="exist_prefill_task_signal", array=exist_prefill_task_signal_data, dtype=np.int32, suffix=current_suffix, create=True, ) # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间 worker_healthy_live_recorded_time_array = np.zeros( shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32 ) self.worker_healthy_live_signal = IPCSignal( name="worker_healthy_live_signal", array=worker_healthy_live_recorded_time_array, dtype=np.int32, suffix=current_suffix, create=True, ) model_weights_status = np.zeros([1], dtype=np.int32) self.model_weights_status_signal = IPCSignal( name="model_weights_status", array=model_weights_status, dtype=np.int32, suffix=current_suffix, create=True, ) def start_worker_queue_service(self): """ start queue service for engine worker communication """ address = ( self.cfg.master_ip, int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]), ) if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0": llm_logger.info(f"Starting engine worker queue server service at {address}") self.engine_worker_queue_server = EngineWorkerQueue( address=address, is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) if ( self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed" and self.cfg.parallel_config.local_data_parallel_id == 0 ): self.cache_task_queue = EngineCacheQueue( address=( self.cfg.master_ip, self.cfg.cache_config.cache_queue_port, ), authkey=b"cache_queue_service", is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=-1, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) llm_logger.info( f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}" ) self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=0, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, local_data_parallel_id=min( self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id, self.cfg.parallel_config.data_parallel_size - 1, ), ) def insert_tasks(self, tasks, current_id=-1, allocated=False): """ Insert tasks to engine. """ for task in tasks: start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) # TODO 返回至 scheduler if allocated: current_tasks = [] for task in tasks: cur_task_idx = self.resource_manager.req_dict[task.request_id] del self.resource_manager.req_dict[task.request_id] cur_task = self.resource_manager.tasks_list[cur_task_idx] cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode": cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids) if task.error_code != 200: self.resource_manager.stop_flags[cur_task_idx] = True self.resource_manager.tasks_list[cur_task_idx] = None self.resource_manager._recycle_block_tables(cur_task) if task.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[task.request_id] self.scheduler.put_results([task]) llm_logger.warning( f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." ) continue self.token_processor.tokens_counter[task.request_id] = 1 current_tasks.append(cur_task) self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) return True self.resource_manager.check_and_free_block_tables() if not isinstance(tasks, list): tasks = [tasks] for item in tasks: item.schedule_start_time = time.time() available_batch = np.sum(self.resource_manager.stop_flags) if len(tasks) > available_batch: llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") llm_logger.error("The exceeded part will be ignored!") tasks = tasks[:available_batch] req_ids = [t.request_id for t in tasks] tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) if not tasks: error_msg = f"The request required resources is exceed the limit, request id={req_ids}." llm_logger.error(error_msg) raise EngineError(error_msg, error_code=500) return False self.token_processor.number_of_tasks += len(tasks) is_decode = False is_prefill = False for i in range(len(tasks)): if tasks[i].disaggregate_info is not None: if tasks[i].disaggregate_info["role"] == "decode": is_decode = True else: is_prefill = True self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len self.split_connector.send_cache_infos(tasks, current_id) if not is_decode: llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") for task in tasks: task.inference_start_time = time.time() if not is_prefill: if not self.cfg.model_config.enable_mm: self.update_requests_chunk_size(tasks) else: self.update_mm_requests_chunk_size(tasks) self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) if is_prefill and self.cfg.scheduler_config.name != "splitwise": self.engine_worker_queue.available_prefill_instances.put(1) return True def task_is_finished(self, index): """ judge if the task is finished """ assert index < len(self.resource_manager.stop_flags) return self.resource_manager.stop_flags[index] def all_tasks_finished(self): """ judge if all tasks are finished """ return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags) def update_requests_chunk_size(self, requests): """ update each request's chunk size info """ def update_tokens(idx, chunk_size, update_chunk=False): nonlocal remain_batched_tokens, chunk_request_num if update_chunk: requests_chunk[idx][-1] += chunk_size else: requests_chunk[idx].append(chunk_size) remain_batched_tokens -= chunk_size current_request_size[idx] -= chunk_size if current_request_size[idx] <= 0: chunk_request_num -= 1 if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0: return current_request_size = [request.prompt_token_ids_len for request in requests] requests_chunk = [[] for _ in range(len(requests))] chunk_request_num = len(current_request_size) while chunk_request_num >= 1: remain_batched_tokens = self.cfg.max_num_batched_tokens for idx in range(len(current_request_size)): if current_request_size[idx] <= 0: continue chunk_size = min( current_request_size[idx], self.partial_chunked_tokens[chunk_request_num], ) update_tokens(idx, chunk_size) while remain_batched_tokens >= self.cfg.cache_config.block_size: # 当前 max_num_batched_tokens 还有剩余时,优先分配给较短的请求 waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0] if len(waiting_requests) == 0: break available_tokens = ( remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size ) append_idx = current_request_size.index(min(waiting_requests)) chunk_size = min( current_request_size[append_idx], self.partial_chunked_tokens[chunk_request_num], available_tokens, ) update_tokens(append_idx, chunk_size, update_chunk=True) for idx in range(len(requests)): requests[idx].set("prefill_chunk_info", requests_chunk[idx]) def update_mm_requests_chunk_size(self, requests): """ update each multimodal request's chunk size info """ if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0: return for request in requests: inputs = request.multimodal_inputs # 兼容没有图片和视频的情况 if inputs["images"] is None: inputs["image_type_ids"] = np.array([], dtype="int32") inputs["grid_thw"] = np.array([], dtype="int64") inputs["images"] = np.array([], dtype="uint8") input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64") image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32") image_mask = input_ids == self.data_processor.image_patch_id image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32") image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32")) grid_thw = [] for one in inputs["grid_thw"]: if one[0] == 1: grid_thw.append(one) else: grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2)) grid_thw = paddle.to_tensor(grid_thw, dtype="int64") from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse chunk_image_num, chunk_seq_len = get_mm_split_fuse( input_ids, image_type_ids, image_token_sum, grid_thw, self.data_processor.image_patch_id, len(grid_thw), 0, len(input_ids), 0, self.partial_chunked_tokens[1], 2048, ) grid_thw = grid_thw.numpy().reshape([-1, 3]) num_chunks = len(chunk_image_num) chunks_info = [] input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0 for idx in range(num_chunks): chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0]) chunk_image_type_ids = inputs["image_type_ids"][ image_type_ids_st : image_type_ids_st + actual_image_num ] chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]] chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1)) chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num] chunks_info.append( { "input_ids": chunk_input_ids, "token_type_ids": chunk_token_type_ids, "image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None), "grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None), "images": (chunk_images if chunk_images.shape[0] else None), "position_ids": None, } ) input_ids_st += chunk_seq_len[idx] image_type_ids_st += actual_image_num grid_thw_st += chunk_image_num[idx] patch_st += chunk_patch_num request.set("prefill_chunk_info", chunks_info) def _insert_task_to_worker(self): """ Insert task to engine thread, monitor scheduler request queue. if the engine has resource, insert task to engine """ current_id = -1 while getattr(self, "running", True): try: if self.resource_manager.available_batch() == 0: time.sleep(0.001) continue if self.engine_worker_queue.num_tasks() > 0: time.sleep(0.001) continue if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0: if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks(): time.sleep(0.005) continue if self.engine_worker_queue.num_cache_infos() > 0: time.sleep(0.001) continue if len(self.split_connector.current_request_ids) > 0: time.sleep(0.001) continue num_prefill_batch = min( int(self.resource_manager.available_batch()), self.cfg.max_prefill_batch, ) self.resource_manager.check_and_free_block_tables() tasks = self.scheduler.get_requests( available_blocks=self.resource_manager.available_block_num(), block_size=self.cfg.cache_config.block_size, reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, max_num_batched_tokens=self.cfg.max_num_batched_tokens, batch=num_prefill_batch, ) if len(tasks) == 0: time.sleep(0.001) continue current_id = (current_id + 1) % 100003 if self.cfg.splitwise_role != "mixed": llm_logger.info("Inserting splitwise tasks") self.split_connector.send_splitwise_tasks(tasks, current_id) self.insert_tasks(tasks, current_id) main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks)) except Exception as e: err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}." llm_logger.error(err_msg) def _scheduler_task_to_worker_v1(self): """ Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). """ get_request_pool = ThreadPoolExecutor(max_workers=1) is_fetching = False def _fetch_request(): nonlocal is_fetching is_fetching = True num_prefill_batch = min( int(self.resource_manager.available_batch()), self.cfg.max_prefill_batch, ) self.resource_manager.check_and_free_block_tables() tasks = self.scheduler.get_requests( available_blocks=self.resource_manager.available_block_num(), block_size=self.cfg.cache_config.block_size, reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, max_num_batched_tokens=self.cfg.max_model_len, batch=num_prefill_batch, ) # Fetch requests and add them to the scheduling queue for task in tasks: self.resource_manager.add_request(task) is_fetching = False while self.running: try: if self.engine_worker_queue.num_tasks() > 0: time.sleep(0.001) continue if ( len(self.resource_manager.waiting) == 0 and (not is_fetching) and self.exist_prefill_task_signal.value[0] == 0 ): get_request_pool.submit(_fetch_request) # 2. Schedule requests tasks = self.resource_manager.schedule() # 3. Send to engine if tasks: self.resource_manager.get_real_bsz() self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) else: time.sleep(0.005) except Exception as e: err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) llm_logger.error(err_msg) def start_zmq_service(self, api_server_pid=None): if api_server_pid is None: return self.api_server_pid = api_server_pid self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL) self.zmq_server.start_server() self.zmq_server.create_router() time.sleep(3) self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True) self.insert_task_to_scheduler_thread.start() self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True) self.receive_output_thread.start() def _insert_zmq_task_to_scheduler(self): added_requests: Dict[str, int] = dict() while self.running: try: block = True if len(added_requests) == 0 else False if not self.cfg.model_config.enable_mm: err, data = self.zmq_server.receive_json_once(block) else: err, data = self.zmq_server.receive_pyobj_once(block) if err is not None: llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}") break request, insert_task = None, [] results: List[Tuple[str, Optional[str]]] = list() if data: err_msg = None try: request = Request.from_dict(data) start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER) llm_logger.debug(f"Receive request: {request}") except Exception as e: llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}") err_msg = str(e) results.append((data["request_id"], err_msg)) if self.guided_decoding_checker is not None and err_msg is None: request, err_msg = self.guided_decoding_checker.schema_format(request) if err_msg is not None: llm_logger.error(f"Receive request error: {err_msg}") results.append((request.request_id, err_msg)) if err_msg is None: insert_task.append(request) response = self.scheduler.put_requests(insert_task) results.extend(response) if request: if request.request_id not in added_requests: added_requests[request.request_id] = 0 added_requests[request.request_id] += 1 for request_id, failed in results: if request_id in added_requests: added_requests[request_id] -= 1 if added_requests[request_id] == 0: added_requests.pop(request_id) if failed is None: main_process_metrics.num_requests_waiting.inc(1) continue error_result = RequestOutput( request_id=request_id, finished=True, error_code=500, error_msg=failed, ) # Since the request is not in scheduler # Send result by zmq directly self.zmq_server.send_multipart(request_id, [error_result]) except Exception as e: llm_logger.error( f"Error happend while receving new request from zmq, details={e}, " f"traceback={traceback.format_exc()}" ) def _zmq_send_generated_tokens(self): """ Recieve output for zmq """ while self.running: try: results = self.scheduler.get_results() if len(results) == 0: time.sleep(0.005) continue for request_id, contents in results.items(): llm_logger.info(f"Send results: {request_id} {contents}") self.zmq_server.send_multipart(request_id, contents) except Exception as e: llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") def split_mode_get_tasks(self): """ Split mode get tasks """ def receiver_loop(): while self.running: try: processed_indices = [] for idx, task in enumerate(self.waiting_requests): if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): self.insert_tasks([task]) llm_logger.info(f"Resource available, processing task {task.request_id}") processed_indices.append(idx) else: llm_logger.debug(f"Still waiting for resources {task.request_id}") break for idx in sorted(processed_indices, reverse=True): self.waiting_requests.pop(idx) if not self.engine_worker_queue.disaggregate_queue_empty(): items = self.engine_worker_queue.get_disaggregated_tasks() for item in items: role = item[0] tasks = item[1] if role == "prefill": for task in tasks: task.max_tokens = task.min_tokens = 2 self.insert_tasks(tasks) elif role == "decode": if hasattr(tasks[0], "finished"): if not isinstance(tasks, list): tasks = [tasks] for task in tasks: task.finished = False self.insert_tasks(tasks, allocated=True) if self.cfg.innode_prefill_ports is not None: self.scheduler.put_results(tasks) else: if len(self.waiting_requests): llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}") self.waiting_requests.extend(tasks) else: new_waiting = [] for task in tasks: if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): self.insert_tasks([task]) else: new_waiting.append(task) if new_waiting: self.waiting_requests.extend(new_waiting) llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue") else: time.sleep(0.001) except Exception as e: llm_logger.error(f"Error in main loop: {e}") time.sleep(0.1) threading.Thread(target=receiver_loop, daemon=True).start() def start_cache_service(self, device_ids, ipc_signal_suffix): return self.resource_manager.cache_manager.launch_cache_manager( cache_config=self.cfg.cache_config, tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, device_ids=device_ids, pod_ip=self.cfg.master_ip, engine_worker_queue_port=int( self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] ), pid_suffix=ipc_signal_suffix, ) def check_and_free_block_tables(self): self.resource_manager.check_and_free_block_tables() def _exit_sub_services(self): """ exit sub services """ self.running = False self.engine_worker_queue_server.cleanup() self.exist_task_signal.clear() self.exist_swapped_task_signal.clear() self.worker_healthy_live_signal.clear() self.exist_prefill_task_signal.clear() self.model_weights_status_signal.clear() if hasattr(self, "zmq_server") and self.zmq_server is not None: self.zmq_server.close()