diff --git a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu index 24e24456d..1c3a45e50 100644 --- a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu +++ b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu @@ -28,6 +28,16 @@ #define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \ switch (num_experts_per_rank) { \ + case 2: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 6: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 6; \ + __VA_ARGS__ \ + break; \ + } \ case 8: { \ constexpr size_t NUM_EXPERTS_PER_RANK = 8; \ __VA_ARGS__ \ break; \ diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 409941f7d..a9a46d4c4 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -23,7 +23,11 @@ import numpy as np import paddle from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager -from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal +from fastdeploy.inter_communicator import ( + EngineWorkerQueue, + IPCSignal, + shared_memory_exists, +) from fastdeploy.utils import get_logger logger = get_logger("cache_messager", "cache_messager.log") @@ -159,36 +163,23 @@ try: prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32) prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32) - try: - step_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}", - array=prefilled_step_idx_data, - dtype=np.int32, - suffix=self.gpu_id, - create=True, - ) - layer_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}", - array=prefilled_layer_idx_data, - dtype=np.int32, - suffix=self.gpu_id, - create=True, - ) - except: - step_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}", - array=prefilled_step_idx_data, - dtype=np.int32, - suffix=self.gpu_id, - create=False, - ) - layer_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}", - array=prefilled_layer_idx_data, - dtype=np.int32, - suffix=self.gpu_id, - create=False, - ) + prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}" + prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}" + step_shm_value = IPCSignal( + name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}", + array=prefilled_step_idx_data, + dtype=np.int32, + suffix=self.gpu_id, + create=not shared_memory_exists(prefilled_step_name), + ) + layer_shm_value = IPCSignal( + name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}", + array=prefilled_layer_idx_data, + dtype=np.int32, + suffix=self.gpu_id, + create=not shared_memory_exists(prefilled_layer_name), + ) + logger.info(f"splitwise_complete_prefilled_step_{self.dp_rank_id}, gpu_id: {self.gpu_id}") step_shm_value.value[0] = -1 layer_shm_value.value[0] = -1 @@ -220,6 +211,7 @@ self.cache_info[info["request_id"]] = info prefilled_layer_idx = layer_shm_value.value[0] prefilled_step_idx = step_shm_value.value[0] + logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}") if prefilled_layer_idx == self.num_layers - 1: time.sleep(0.001) prefilled_layer_idx = layer_shm_value.value[0] diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 9b0192ec8..0b949339e 100644 --- a/fastdeploy/config.py +++ 
b/fastdeploy/config.py @@ -95,7 +95,7 @@ PRETRAINED_INIT_CONFIGURATION = { "start_layer_index": 0, "moe_num_shared_experts": 0, "moe_layer_start_index": 0, - "num_max_dispatch_tokens_per_rank": 256, + "num_max_dispatch_tokens_per_rank": 128, "moe_use_aux_free": False, "vocab_size": -1, "hidden_dropout_prob": 0.0, @@ -278,7 +278,7 @@ class ParallelConfig: # block size self.block_size: int = 64 # Engine worker queue port - self.engine_worker_queue_port: int = 9923 + self.engine_worker_queue_port: str = "9923" # Max model len self.max_model_len: int = 3072 # max_seq_len # cuda visible devices @@ -307,7 +307,11 @@ class ParallelConfig: for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) - + if isinstance(self.engine_worker_queue_port, str): + self.engine_worker_queue_port = [int(port) for port in self.engine_worker_queue_port.split(",")] + logger.info(f"engine_worker_queue_port: {self.engine_worker_queue_port}") + elif isinstance(self.engine_worker_queue_port, int): + self.engine_worker_queue_port = [self.engine_worker_queue_port] # currently, the expert parallel size is equal data parallel size if self.enable_expert_parallel: self.expert_parallel_size = self.data_parallel_size * self.tensor_parallel_size @@ -1038,7 +1042,7 @@ class FDConfig: max_num_batched_tokens: Optional[int] = None, ips: str = None, use_warmup: bool = False, - engine_worker_queue_port: int = 8002, + engine_worker_queue_port: str = "8002", limit_mm_per_prompt: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, splitwise_role: str = "mixed", @@ -1082,11 +1086,10 @@ class FDConfig: if self.ips is None: self.master_ip = "0.0.0.0" - elif isinstance(self.ips, list): - self.master_ip = self.ips[0] - else: + elif isinstance(self.ips, str): self.ips = self.ips.split(",") - self.master_ip = self.ips[0] + + self.host_ip = get_host_ip() if self.ips is None: self.nnode = 1 @@ -1095,7 +1098,7 @@ class FDConfig: self.nnode = len(self.ips) for idx, ip in enumerate(self.ips): - if ip == self.master_ip: + if ip == self.host_ip: self.node_rank = idx self.max_model_len = max_model_len @@ -1111,7 +1114,11 @@ class FDConfig: self.reasoning_parser = reasoning_parser self.guided_decoding_backend = guided_decoding_backend self.disable_any_whitespace = disable_any_whitespace + self.engine_worker_queue_port = engine_worker_queue_port self._str_to_list("innode_prefill_ports", int) + if isinstance(engine_worker_queue_port, int): + self.engine_worker_queue_port = str(engine_worker_queue_port) + self._str_to_list("engine_worker_queue_port", str) if envs.FD_FOR_TORCH_MODEL_FORMAT: self.model_config.model_format = "torch" @@ -1129,10 +1136,11 @@ class FDConfig: self.worker_num_per_node = self.max_chips_per_node nnode = ceil_div(num_ranks, self.worker_num_per_node) assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}" + + # assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}" else: self.worker_num_per_node = num_ranks - self.engine_worker_queue_port = engine_worker_queue_port self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)]) self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids) if current_platform.is_xpu(): @@ -1155,15 +1163,12 @@ class FDConfig: self.local_device_ids = self.device_ids.split(",")[: self.parallel_config.tensor_parallel_size] - self.host_ip = get_host_ip() - - if self.ips is None or self.host_ip == self.master_ip: - self.is_master = True - else: - self.is_master = False - if 
self.parallel_config.tensor_parallel_size <= self.worker_num_per_node: self.is_master = True + self.master_ip = "0.0.0.0" + else: + self.is_master = False + self.master_ip = self.ips[0] self.paddle_commit_id = paddle.version.commit @@ -1345,10 +1350,12 @@ class FDConfig: def _str_to_list(self, attr_name, default_type): if hasattr(self, attr_name): val = getattr(self, attr_name) + if val is None: + return if type(val) is str: setattr(self, attr_name, [default_type(i) for i in val.split(",")]) else: - setattr(self, attr_name, val) + setattr(self, attr_name, [default_type(i) for i in val]) def __str__(self) -> str: return json.dumps(self.__dict__, indent=4) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index cbdb0cecf..a0b72e8e1 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -193,7 +193,7 @@ class EngineArgs: Flag to enable the custom all-reduce kernel. """ - engine_worker_queue_port: int = 8002 + engine_worker_queue_port: str = "8002" """ Port for worker queue communication. """ @@ -208,6 +208,11 @@ class EngineArgs: Number of data parallelism. """ + local_data_parallel_id: int = 0 + """ + Local data parallel id. + """ + enable_expert_parallel: bool = False """ Enable expert parallelism. @@ -498,7 +503,7 @@ class EngineArgs: ) model_group.add_argument( "--engine-worker-queue-port", - type=int, + type=lambda s: s.split(",") if s else None, default=EngineArgs.engine_worker_queue_port, help="port for engine worker queue", ) @@ -607,6 +612,13 @@ class EngineArgs: default=EngineArgs.data_parallel_size, help="Degree of data parallelism.", ) + + parallel_group.add_argument( + "--local-data-parallel-id", + type=int, + default=EngineArgs.local_data_parallel_id, + help="the rank of data parallelism.", + ) parallel_group.add_argument( "--enable-expert-parallel", action="store_true", @@ -947,8 +959,13 @@ class EngineArgs: early_stop_cfg = self.create_early_stop_config() early_stop_cfg.update_enable_early_stop(self.enable_early_stop) + if isinstance(self.engine_worker_queue_port, int): + self.engine_worker_queue_port = str(self.engine_worker_queue_port) + if isinstance(self.engine_worker_queue_port, str): + self.engine_worker_queue_port = self.engine_worker_queue_port.split(",") + assert is_port_available( - "0.0.0.0", self.engine_worker_queue_port + "0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id]) ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use." return FDConfig( diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py new file mode 100644 index 000000000..801067952 --- /dev/null +++ b/fastdeploy/engine/common_engine.py @@ -0,0 +1,754 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from __future__ import annotations + +import copy +import os +import threading +import time +import traceback +import weakref +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Tuple + +import numpy as np +import paddle +import zmq +from opentelemetry import trace + +from fastdeploy.engine.request import Request, RequestOutput +from fastdeploy.engine.resource_manager import ResourceManager +from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 +from fastdeploy.inter_communicator import ( + EngineCacheQueue, + EngineWorkerQueue, + IPCSignal, + ZmqClient, +) +from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.metrics.trace_util import start_span, start_span_request +from fastdeploy.model_executor.guided_decoding import schema_checker +from fastdeploy.output.token_processor import TokenProcessor +from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector +from fastdeploy.utils import EngineError, envs, llm_logger + + +class EngineSevice: + """ + Base class containing common engine functionality + """ + + def __init__(self, cfg): + """ + Initializes the LLMEngine with the provided configuration. + + Args: + cfg (Config): Config object containing all the configuration parameters. + """ + self.cfg = cfg + + self.scheduler = cfg.scheduler_config.scheduler() + + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager = ResourceManagerV1( + cfg.max_num_seqs, + cfg, + cfg.parallel_config.tensor_parallel_size, + cfg.splitwise_role, + cfg.parallel_config.local_data_parallel_id, + ) + if cfg.splitwise_role != "mixed": + raise NotImplementedError( + "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now." + ) + else: + self.resource_manager = ResourceManager( + cfg.max_num_seqs, + cfg, + cfg.parallel_config.tensor_parallel_size, + cfg.splitwise_role, + cfg.parallel_config.local_data_parallel_id, + ) + + self.start_worker_queue_service() + + os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[ + self.cfg.parallel_config.local_data_parallel_id + ] + + self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager) + self.waiting_requests = [] + self.token_processor = TokenProcessor( + cfg=cfg, + cached_generated_tokens=self.scheduler, + engine_worker_queue=self.engine_worker_queue, + split_connector=self.split_connector, + ) + self.token_processor.set_resource_manager(self.resource_manager) + + self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1) + for idx in range(1, self.cfg.max_num_partial_prefills + 1): + self.partial_chunked_tokens[idx] = ( + (self.cfg.max_num_batched_tokens // idx) + // self.cfg.cache_config.block_size + * self.cfg.cache_config.block_size + ) + + self.guided_decoding_checker = None + if self.cfg.guided_decoding_backend != "off": + self.guided_decoding_checker = schema_checker( + self.cfg.guided_decoding_backend, + disable_any_whitespace=self.cfg.disable_any_whitespace, + ) + self._init_worker_monitor_signals() + + self._finalizer = weakref.finalize(self, self._exit_sub_services) + + def start(self): + self.running = True + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True) + else: + self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True) + self.insert_task_to_worker_thread.start() + self.token_processor.tasks_queue = 
self.engine_worker_queue + self.token_processor.run() + + def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理 + current_suffix = int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]) + llm_logger.info(f"current_suffix: {current_suffix}") + exist_task_signal_data = np.zeros([1], dtype=np.int32) + self.exist_task_signal = IPCSignal( + name="exist_task_signal", + array=exist_task_signal_data, + dtype=np.int32, + suffix=current_suffix, + create=True, + ) + + # exist_swapped_task_signal 用于engine感知worker中是否存在swapped task + exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32) + self.exist_swapped_task_signal = IPCSignal( + name="exist_swapped_task_signal", + array=exist_swapped_task_signal_data, + dtype=np.int32, + suffix=current_suffix, + create=True, + ) + + # exist_prefill_task_signal 用于各worker进程感知是否进行prefill + exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32) + self.exist_prefill_task_signal = IPCSignal( + name="exist_prefill_task_signal", + array=exist_prefill_task_signal_data, + dtype=np.int32, + suffix=current_suffix, + create=True, + ) + + # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间 + worker_healthy_live_recorded_time_array = np.zeros( + shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32 + ) + self.worker_healthy_live_signal = IPCSignal( + name="worker_healthy_live_signal", + array=worker_healthy_live_recorded_time_array, + dtype=np.int32, + suffix=current_suffix, + create=True, + ) + + model_weights_status = np.zeros([1], dtype=np.int32) + self.model_weights_status_signal = IPCSignal( + name="model_weights_status", + array=model_weights_status, + dtype=np.int32, + suffix=current_suffix, + create=True, + ) + + def start_worker_queue_service(self): + """ + start queue service for engine worker communication + """ + address = ( + self.cfg.master_ip, + int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]), + ) + if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0": + llm_logger.info(f"Starting engine worker queue server service at {address}") + self.engine_worker_queue_server = EngineWorkerQueue( + address=address, + is_server=True, + num_client=self.cfg.parallel_config.tensor_parallel_size, + local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, + ) + + if ( + self.cfg.cache_config.enable_prefix_caching + or self.cfg.splitwise_role != "mixed" + and self.cfg.parallel_config.local_data_parallel_id == 0 + ): + self.cache_task_queue = EngineCacheQueue( + address=( + self.cfg.master_ip, + self.cfg.cache_config.cache_queue_port, + ), + authkey=b"cache_queue_service", + is_server=True, + num_client=self.cfg.parallel_config.tensor_parallel_size, + client_id=-1, + local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, + ) + llm_logger.info( + f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}" + ) + self.engine_worker_queue = EngineWorkerQueue( + address=address, + is_server=False, + num_client=self.cfg.parallel_config.tensor_parallel_size, + client_id=0, + local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, + local_data_parallel_id=min( + self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id, + self.cfg.parallel_config.data_parallel_size - 1, + ), + ) + + def insert_tasks(self, 
tasks, current_id=-1, allocated=False): + """ + Insert tasks to engine. + """ + for task in tasks: + start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) + + # TODO 返回至 scheduler + if allocated: + current_tasks = [] + for task in tasks: + cur_task_idx = self.resource_manager.req_dict[task.request_id] + del self.resource_manager.req_dict[task.request_id] + cur_task = self.resource_manager.tasks_list[cur_task_idx] + cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] + if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode": + cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids) + if task.error_code != 200: + self.resource_manager.stop_flags[cur_task_idx] = True + self.resource_manager.tasks_list[cur_task_idx] = None + self.resource_manager._recycle_block_tables(cur_task) + if task.request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[task.request_id] + self.scheduler.put_results([task]) + llm_logger.warning( + f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." + ) + continue + self.token_processor.tokens_counter[task.request_id] = 1 + current_tasks.append(cur_task) + self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) + return True + + self.resource_manager.check_and_free_block_tables() + + if not isinstance(tasks, list): + tasks = [tasks] + + for item in tasks: + item.schedule_start_time = time.time() + + available_batch = np.sum(self.resource_manager.stop_flags) + if len(tasks) > available_batch: + llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") + llm_logger.error("The exceeded part will be ignored!") + tasks = tasks[:available_batch] + + req_ids = [t.request_id for t in tasks] + + tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) + + if not tasks: + error_msg = f"The request required resources is exceed the limit, request id={req_ids}." 
+ llm_logger.error(error_msg) + raise EngineError(error_msg, error_code=500) + return False + + self.token_processor.number_of_tasks += len(tasks) + + is_decode = False + is_prefill = False + for i in range(len(tasks)): + if tasks[i].disaggregate_info is not None: + if tasks[i].disaggregate_info["role"] == "decode": + is_decode = True + else: + is_prefill = True + self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len + + self.split_connector.send_cache_infos(tasks, current_id) + if not is_decode: + llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") + for task in tasks: + task.inference_start_time = time.time() + if not is_prefill: + if not self.cfg.model_config.enable_mm: + self.update_requests_chunk_size(tasks) + else: + self.update_mm_requests_chunk_size(tasks) + self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) + if is_prefill and self.cfg.scheduler_config.name != "splitwise": + self.engine_worker_queue.available_prefill_instances.put(1) + return True + + def task_is_finished(self, index): + """ + judge if the task is finished + """ + assert index < len(self.resource_manager.stop_flags) + return self.resource_manager.stop_flags[index] + + def all_tasks_finished(self): + """ + judge if all tasks are finished + """ + return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags) + + def update_requests_chunk_size(self, requests): + """ + update each request's chunk size info + """ + + def update_tokens(idx, chunk_size, update_chunk=False): + nonlocal remain_batched_tokens, chunk_request_num + if update_chunk: + requests_chunk[idx][-1] += chunk_size + else: + requests_chunk[idx].append(chunk_size) + remain_batched_tokens -= chunk_size + current_request_size[idx] -= chunk_size + if current_request_size[idx] <= 0: + chunk_request_num -= 1 + + if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0: + return + + current_request_size = [request.prompt_token_ids_len for request in requests] + requests_chunk = [[] for _ in range(len(requests))] + chunk_request_num = len(current_request_size) + while chunk_request_num >= 1: + remain_batched_tokens = self.cfg.max_num_batched_tokens + for idx in range(len(current_request_size)): + if current_request_size[idx] <= 0: + continue + chunk_size = min( + current_request_size[idx], + self.partial_chunked_tokens[chunk_request_num], + ) + update_tokens(idx, chunk_size) + + while remain_batched_tokens >= self.cfg.cache_config.block_size: + # 当前 max_num_batched_tokens 还有剩余时,优先分配给较短的请求 + waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0] + if len(waiting_requests) == 0: + break + + available_tokens = ( + remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size + ) + append_idx = current_request_size.index(min(waiting_requests)) + chunk_size = min( + current_request_size[append_idx], + self.partial_chunked_tokens[chunk_request_num], + available_tokens, + ) + update_tokens(append_idx, chunk_size, update_chunk=True) + + for idx in range(len(requests)): + requests[idx].set("prefill_chunk_info", requests_chunk[idx]) + + def update_mm_requests_chunk_size(self, requests): + """ + update each multimodal request's chunk size info + """ + if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0: + return + + for request in requests: + inputs = request.multimodal_inputs + # 兼容没有图片和视频的情况 + if inputs["images"] is None: + inputs["image_type_ids"] = np.array([], dtype="int32") + 
inputs["grid_thw"] = np.array([], dtype="int64") + inputs["images"] = np.array([], dtype="uint8") + input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64") + image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32") + image_mask = input_ids == self.data_processor.image_patch_id + image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32") + image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32")) + grid_thw = [] + for one in inputs["grid_thw"]: + if one[0] == 1: + grid_thw.append(one) + else: + grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2)) + grid_thw = paddle.to_tensor(grid_thw, dtype="int64") + + from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse + + chunk_image_num, chunk_seq_len = get_mm_split_fuse( + input_ids, + image_type_ids, + image_token_sum, + grid_thw, + self.data_processor.image_patch_id, + len(grid_thw), + 0, + len(input_ids), + 0, + self.partial_chunked_tokens[1], + 2048, + ) + + grid_thw = grid_thw.numpy().reshape([-1, 3]) + num_chunks = len(chunk_image_num) + chunks_info = [] + input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0 + for idx in range(num_chunks): + chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] + chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] + actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0]) + chunk_image_type_ids = inputs["image_type_ids"][ + image_type_ids_st : image_type_ids_st + actual_image_num + ] + chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]] + chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1)) + chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num] + + chunks_info.append( + { + "input_ids": chunk_input_ids, + "token_type_ids": chunk_token_type_ids, + "image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None), + "grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None), + "images": (chunk_images if chunk_images.shape[0] else None), + "position_ids": None, + } + ) + + input_ids_st += chunk_seq_len[idx] + image_type_ids_st += actual_image_num + grid_thw_st += chunk_image_num[idx] + patch_st += chunk_patch_num + request.set("prefill_chunk_info", chunks_info) + + def _insert_task_to_worker(self): + """ + Insert task to engine thread, monitor scheduler request queue. 
+ if the engine has resource, insert task to engine + """ + current_id = -1 + while getattr(self, "running", True): + try: + if self.resource_manager.available_batch() == 0: + time.sleep(0.001) + continue + if self.engine_worker_queue.num_tasks() > 0: + time.sleep(0.001) + continue + if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0: + if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks(): + time.sleep(0.005) + continue + if self.engine_worker_queue.num_cache_infos() > 0: + time.sleep(0.001) + continue + if len(self.split_connector.current_request_ids) > 0: + time.sleep(0.001) + continue + + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + + self.resource_manager.check_and_free_block_tables() + tasks = self.scheduler.get_requests( + available_blocks=self.resource_manager.available_block_num(), + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, + max_num_batched_tokens=self.cfg.max_num_batched_tokens, + batch=num_prefill_batch, + ) + + if len(tasks) == 0: + time.sleep(0.001) + continue + + current_id = (current_id + 1) % 100003 + if self.cfg.splitwise_role != "mixed": + llm_logger.info("Inserting splitwise tasks") + self.split_connector.send_splitwise_tasks(tasks, current_id) + + self.insert_tasks(tasks, current_id) + + main_process_metrics.num_requests_waiting.dec(len(tasks)) + main_process_metrics.num_requests_running.inc(len(tasks)) + except Exception as e: + err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}." + llm_logger.error(err_msg) + + def _scheduler_task_to_worker_v1(self): + """ + Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). + """ + get_request_pool = ThreadPoolExecutor(max_workers=1) + is_fetching = False + + def _fetch_request(): + nonlocal is_fetching + is_fetching = True + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + + self.resource_manager.check_and_free_block_tables() + tasks = self.scheduler.get_requests( + available_blocks=self.resource_manager.available_block_num(), + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, + max_num_batched_tokens=self.cfg.max_model_len, + batch=num_prefill_batch, + ) + # Fetch requests and add them to the scheduling queue + for task in tasks: + self.resource_manager.add_request(task) + is_fetching = False + + while self.running: + try: + if self.engine_worker_queue.num_tasks() > 0: + time.sleep(0.001) + continue + if ( + len(self.resource_manager.waiting) == 0 + and (not is_fetching) + and self.exist_prefill_task_signal.value[0] == 0 + ): + get_request_pool.submit(_fetch_request) + # 2. Schedule requests + tasks = self.resource_manager.schedule() + # 3. 
Send to engine + if tasks: + self.resource_manager.get_real_bsz() + self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) + else: + time.sleep(0.005) + + except Exception as e: + err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) + llm_logger.error(err_msg) + + def start_zmq_service(self, api_server_pid=None): + if api_server_pid is None: + return + self.api_server_pid = api_server_pid + self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL) + self.zmq_server.start_server() + self.zmq_server.create_router() + time.sleep(3) + self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True) + self.insert_task_to_scheduler_thread.start() + + self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True) + self.receive_output_thread.start() + + def _insert_zmq_task_to_scheduler(self): + added_requests: Dict[str, int] = dict() + while self.running: + try: + block = True if len(added_requests) == 0 else False + if not self.cfg.model_config.enable_mm: + err, data = self.zmq_server.receive_json_once(block) + else: + err, data = self.zmq_server.receive_pyobj_once(block) + if err is not None: + llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}") + break + + request, insert_task = None, [] + results: List[Tuple[str, Optional[str]]] = list() + if data: + request = Request.from_dict(data) + start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER) + llm_logger.debug(f"Receive request: {request}") + + err_msg = None + if self.guided_decoding_checker is not None: + request, err_msg = self.guided_decoding_checker.schema_format(request) + + if err_msg is not None: + llm_logger.error(err_msg) + results.append((request.request_id, err_msg)) + else: + insert_task.append(request) + + response = self.scheduler.put_requests(insert_task) + results.extend(response) + + if request: + if request.request_id not in added_requests: + added_requests[request.request_id] = 0 + added_requests[request.request_id] += 1 + + for request_id, failed in results: + added_requests[request_id] -= 1 + if added_requests[request_id] == 0: + added_requests.pop(request_id) + + if failed is None: + main_process_metrics.num_requests_waiting.inc(1) + continue + + error_result = RequestOutput( + request_id=request_id, + finished=True, + error_code=500, + error_msg=failed, + ) + # Since the request is not in scheduler + # Send result by zmq directly + self.zmq_server.send_multipart(request_id, error_result) + except Exception as e: + llm_logger.error( + f"Error happend while receving new request from zmq, details={e}, " + f"traceback={traceback.format_exc()}" + ) + + def _zmq_send_generated_tokens(self): + """ + Recieve output for zmq + """ + while self.running: + try: + results = self.scheduler.get_results() + if len(results) == 0: + time.sleep(0.005) + continue + for request_id, contents in results.items(): + llm_logger.info(f"Send results: {request_id} {contents}") + self.zmq_server.send_multipart(request_id, contents) + + except Exception as e: + llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") + + def split_mode_get_tasks(self): + """ + Split mode get tasks + """ + + def receiver_loop(): + while self.running: + try: + + processed_indices = [] + for idx, task in enumerate(self.waiting_requests): + if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): + self.insert_tasks([task]) + 
llm_logger.info(f"Resource available, processing task {task.request_id}") + processed_indices.append(idx) + else: + llm_logger.debug(f"Still waiting for resources {task.request_id}") + break + + for idx in sorted(processed_indices, reverse=True): + self.waiting_requests.pop(idx) + + if not self.engine_worker_queue.disaggregate_queue_empty(): + items = self.engine_worker_queue.get_disaggregated_tasks() + for item in items: + role = item[0] + tasks = item[1] + + if role == "prefill": + for task in tasks: + task.max_tokens = task.min_tokens = 2 + self.insert_tasks(tasks) + + elif role == "decode": + if hasattr(tasks[0], "finished"): + if not isinstance(tasks, list): + tasks = [tasks] + for task in tasks: + task.finished = False + self.insert_tasks(tasks, allocated=True) + + if self.cfg.innode_prefill_ports is not None: + self.scheduler.put_results(tasks) + + else: + if len(self.waiting_requests): + llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}") + self.waiting_requests.extend(tasks) + else: + new_waiting = [] + for task in tasks: + if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): + self.insert_tasks([task]) + else: + new_waiting.append(task) + + if new_waiting: + self.waiting_requests.extend(new_waiting) + llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue") + + else: + time.sleep(0.001) + + except Exception as e: + llm_logger.error(f"Error in main loop: {e}") + time.sleep(0.1) + + threading.Thread(target=receiver_loop, daemon=True).start() + + def start_cache_service(self, device_ids, ipc_signal_suffix): + return self.resource_manager.cache_manager.launch_cache_manager( + cache_config=self.cfg.cache_config, + tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, + device_ids=device_ids, + pod_ip=self.cfg.master_ip, + engine_worker_queue_port=int( + self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] + ), + pid_suffix=ipc_signal_suffix, + ) + + def check_and_free_block_tables(self): + self.resource_manager.check_and_free_block_tables() + + def _exit_sub_services(self): + """ + exit sub services + """ + self.running = False + self.engine_worker_queue_server.cleanup() + self.exist_task_signal.clear() + self.exist_swapped_task_signal.clear() + self.worker_healthy_live_signal.clear() + self.exist_prefill_task_signal.clear() + self.model_weights_status_signal.clear() + if hasattr(self, "zmq_server") and self.zmq_server is not None: + self.zmq_server.close() diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 8b49f2659..de215d183 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -16,7 +16,6 @@ from __future__ import annotations -import copy import multiprocessing import os import re @@ -28,32 +27,17 @@ import time import traceback import uuid import weakref -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Tuple import numpy as np import paddle -import zmq -from opentelemetry import trace from tqdm import tqdm from fastdeploy.engine.args_utils import EngineArgs -from fastdeploy.engine.expert_service import start_expert_service -from fastdeploy.engine.request import Request, RequestOutput -from fastdeploy.engine.resource_manager import ResourceManager -from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 +from fastdeploy.engine.common_engine import EngineSevice +from fastdeploy.engine.expert_service import start_data_parallel_service +from fastdeploy.engine.request import 
Request from fastdeploy.input.preprocess import InputPreprocessor -from fastdeploy.inter_communicator import ( - EngineCacheQueue, - EngineWorkerQueue, - IPCSignal, - ZmqClient, -) -from fastdeploy.metrics.metrics import main_process_metrics -from fastdeploy.metrics.trace_util import start_span, start_span_request -from fastdeploy.model_executor.guided_decoding import schema_checker -from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor -from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector +from fastdeploy.inter_communicator import IPCSignal from fastdeploy.utils import EngineError, console_logger, envs, llm_logger @@ -98,7 +82,7 @@ class LLMEngine: """ self.cfg = cfg self.running = True - self.scheduler = cfg.scheduler_config.scheduler() + self.is_started = False self.input_processor = InputPreprocessor( cfg.tokenizer, @@ -108,61 +92,14 @@ class LLMEngine: cfg.model_config.enable_mm, cfg.tool_parser, ) - - self.start_queue_service() - - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.resource_manager = ResourceManagerV1( - cfg.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role - ) - if cfg.splitwise_role != "mixed": - raise NotImplementedError( - "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now." - ) - else: - self.resource_manager = ResourceManager( - cfg.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role - ) - - os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port) - - self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager) - - self.token_processor = TokenProcessor( - cfg=self.cfg, - cached_generated_tokens=self.scheduler, - engine_worker_queue=self.engine_worker_queue, - split_connector=self.split_connector, - ) - self.token_processor.set_resource_manager(self.resource_manager) - - self.is_started = False - - self.waiting_requests = [] + self.engine = EngineSevice(cfg) if self.cfg.cache_config.num_gpu_blocks_override is None: self.do_profile = 1 else: self.do_profile = 0 - - self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1) - for idx in range(1, self.cfg.max_num_partial_prefills + 1): - self.partial_chunked_tokens[idx] = ( - (self.cfg.max_num_batched_tokens // idx) - // self.cfg.cache_config.block_size - * self.cfg.cache_config.block_size - ) - self.partial_chunked_tokens[idx] = max(1, self.partial_chunked_tokens[idx]) - self._finalizer = weakref.finalize(self, self._exit_sub_services) - self.guided_decoding_checker = None - if self.cfg.guided_decoding_backend != "off": - self.guided_decoding_checker = schema_checker( - self.cfg.guided_decoding_backend, - disable_any_whitespace=self.cfg.disable_any_whitespace, - ) - def start(self, api_server_pid=None): """ Initializes the engine and starts its sub-services. 
@@ -173,30 +110,22 @@ class LLMEngine: start_time = time.time() self.api_server_pid = api_server_pid - self.engine_pid = os.getpid() - self.ipc_signal_suffix = self.engine_pid if self.api_server_pid is None else self.api_server_pid + self.ipc_signal_suffix = self.cfg.engine_worker_queue_port[0] self._init_worker_signals() self.data_processor = self.input_processor.create_processor() + self.engine.data_processor = self.data_processor + self.engine.start() if api_server_pid is not None: - self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL) - self.zmq_server.start_server() - self.zmq_server.create_router() - time.sleep(3) + llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}") + self.engine.start_zmq_service(api_server_pid) if self.do_profile == 0 and ( self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed" ): device_ids = self.cfg.device_ids.split(",") - self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager( - cache_config=self.cfg.cache_config, - tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, - device_ids=device_ids, - pod_ip=self.cfg.master_ip, - engine_worker_queue_port=self.cfg.engine_worker_queue_port, - pid_suffix=self.ipc_signal_suffix, - ) + self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix) self.worker_proc = self._start_worker_service() console_logger.info("Waiting worker processes ready...") @@ -236,214 +165,17 @@ class LLMEngine: console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.") return False - # Start warmup if enabled - if self.cfg.use_warmup: - console_logger.info("Starting warmup") - self._set_warmup_token_processor() - self.warmup() - self._del_warmup_token_processor() - console_logger.info("Warmup finished") - console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True - def _zmq_send_generated_tokens(self): - """ - Recieve output for zmq - """ - assert self.api_server_pid is not None - while self.running: - try: - results = self.scheduler.get_results() - if len(results) == 0: - time.sleep(0.005) - continue - for request_id, contents in results.items(): - self.zmq_server.send_multipart(request_id, contents) - - except Exception as e: - llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") - def _get_generated_result(self): """ Get result from scheduler, this function is called by generate() which is only used in offline inference. """ - return self.scheduler.get_results() + return self.engine.scheduler.get_results() - def _insert_task_to_worker(self): - """ - Insert task to engine thread, monitor scheduler request queue. 
- if the engine has resource, insert task to engine - """ - current_id = -1 - while self.running: - try: - if self.resource_manager.available_batch() == 0: - time.sleep(0.001) - continue - if self.engine_worker_queue.num_tasks() > 0: - time.sleep(0.001) - continue - if self.exist_prefill_task_signal.value[0] > 0: - if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks(): - time.sleep(0.005) - continue - if self.engine_worker_queue.num_cache_infos() > 0: - time.sleep(0.001) - continue - if len(self.split_connector.current_request_ids) > 0: - time.sleep(0.001) - continue - - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - - self.resource_manager.check_and_free_block_tables() - tasks = self.scheduler.get_requests( - available_blocks=self.resource_manager.available_block_num(), - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, - max_num_batched_tokens=self.cfg.max_num_batched_tokens, - batch=num_prefill_batch, - ) - - if len(tasks) == 0: - time.sleep(0.001) - continue - - current_id = (current_id + 1) % 100003 - if self.cfg.splitwise_role != "mixed": - llm_logger.info("Inserting splitwise tasks") - self.split_connector.send_splitwise_tasks(tasks, current_id) - - self.insert_tasks(tasks, current_id) - - main_process_metrics.num_requests_waiting.dec(len(tasks)) - main_process_metrics.num_requests_running.inc(len(tasks)) - except Exception as e: - err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}." - llm_logger.error(err_msg) - - def _scheduler_task_to_worker_v1(self): - """ - Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). - """ - get_request_pool = ThreadPoolExecutor(max_workers=1) - is_fetching = False - - def _fetch_request(): - nonlocal is_fetching - is_fetching = True - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - - self.resource_manager.check_and_free_block_tables() - tasks = self.scheduler.get_requests( - available_blocks=self.resource_manager.available_block_num(), - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, - max_num_batched_tokens=self.cfg.max_model_len, - batch=num_prefill_batch, - ) - # Fetch requests and add them to the scheduling queue - for task in tasks: - self.resource_manager.add_request(task) - is_fetching = False - - while self.running: - try: - if self.engine_worker_queue.num_tasks() > 0: - time.sleep(0.001) - continue - if ( - len(self.resource_manager.waiting) == 0 - and (not is_fetching) - and self.exist_prefill_task_signal.value[0] == 0 - ): - get_request_pool.submit(_fetch_request) - # 2. Schedule requests - tasks = self.resource_manager.schedule() - # 3. 
Send to engine - if tasks: - self.resource_manager.get_real_bsz() - self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) - else: - time.sleep(0.005) - - except Exception as e: - err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) - llm_logger.error(err_msg) - - def _insert_zmq_task_to_scheduler(self): - if self.api_server_pid is None: - return - - added_requests: Dict[str, int] = dict() - while self.running: - try: - block = True if len(added_requests) == 0 else False - if not self.cfg.model_config.enable_mm: - err, data = self.zmq_server.receive_json_once(block) - else: - err, data = self.zmq_server.receive_pyobj_once(block) - if err is not None: - llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}") - break - - request, insert_task = None, [] - results: List[Tuple[str, Optional[str]]] = list() - if data: - request = Request.from_dict(data) - start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER) - - llm_logger.debug(f"Receive request: {request}") - - err_msg = None - if self.guided_decoding_checker is not None: - request, err_msg = self.guided_decoding_checker.schema_format(request) - - if err_msg is not None: - llm_logger.error(err_msg) - results.append((request.request_id, err_msg)) - else: - insert_task.append(request) - - response = self.scheduler.put_requests(insert_task) - results.extend(response) - - if request: - if request.request_id not in added_requests: - added_requests[request.request_id] = 0 - added_requests[request.request_id] += 1 - - for request_id, failed in results: - added_requests[request_id] -= 1 - if added_requests[request_id] == 0: - added_requests.pop(request_id) - - if failed is None: - main_process_metrics.num_requests_waiting.inc(1) - continue - - error_result = RequestOutput( - request_id=request_id, - finished=True, - error_code=500, - error_msg=failed, - ) - # Since the request is not in scheduler - # Send result by zmq directly - self.zmq_server.send_multipart(request_id, error_result) - except Exception as e: - llm_logger.error( - f"Error happend while receving new request from zmq, details={e}, " - f"traceback={traceback.format_exc()}" - ) + # _insert_task_to_worker moved to CommonEngine def add_requests(self, task, sampling_params=None, **kwargs): """ @@ -514,341 +246,17 @@ class LLMEngine: llm_logger.error(error_msg) raise EngineError(error_msg, error_code=400) - if self.guided_decoding_checker is not None: - request, err_msg = self.guided_decoding_checker.schema_format(request) + if self.engine.guided_decoding_checker is not None: + request, err_msg = self.engine.guided_decoding_checker.schema_format(request) if err_msg is not None: llm_logger.error(err_msg) raise EngineError(err_msg, error_code=400) request.preprocess_end_time = time.time() - self.scheduler.put_requests([request]) + self.engine.scheduler.put_requests([request]) llm_logger.info(f"Cache task with request_id ({request.get('request_id')})") llm_logger.debug(f"cache task: {request}") - def warmup(self): - """ - construct test tasks and avoid out of memory problem in the worker process - """ - # get eos_token_id - pass - - def split_mode_get_tasks(self): - """ - Split mode get tasks - """ - - def receiver_loop(): - while self.running: - try: - - processed_indices = [] - for idx, task in enumerate(self.waiting_requests): - if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): - self.insert_tasks([task]) - llm_logger.info(f"Resource available, processing task 
{task.request_id}") - processed_indices.append(idx) - else: - llm_logger.debug(f"Still waiting for resources {task.request_id}") - break - - for idx in sorted(processed_indices, reverse=True): - self.waiting_requests.pop(idx) - - if not self.engine_worker_queue.disaggregate_queue_empty(): - items = self.engine_worker_queue.get_disaggregated_tasks() - for item in items: - role = item[0] - tasks = item[1] - - if role == "prefill": - for task in tasks: - task.max_tokens = task.min_tokens = 2 - self.insert_tasks(tasks) - - elif role == "decode": - if hasattr(tasks[0], "finished"): - if not isinstance(tasks, list): - tasks = [tasks] - for task in tasks: - task.finished = False - self.insert_tasks(tasks, allocated=True) - - if self.cfg.innode_prefill_ports is not None: - self.scheduler.put_results(tasks) - - else: - if len(self.waiting_requests): - llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}") - self.waiting_requests.extend(tasks) - else: - new_waiting = [] - for task in tasks: - if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): - self.insert_tasks([task]) - else: - new_waiting.append(task) - - if new_waiting: - self.waiting_requests.extend(new_waiting) - llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue") - - else: - time.sleep(0.001) - - except Exception as e: - llm_logger.error(f"Error in main loop: {e}, {str(traceback.format_exc())}") - time.sleep(0.1) - - threading.Thread(target=receiver_loop, daemon=True).start() - - def update_requests_chunk_size(self, requests): - """ - update each request's chunk size info - """ - - def update_tokens(idx, chunk_size, update_chunk=False): - nonlocal remain_batched_tokens, chunk_request_num - if update_chunk: - requests_chunk[idx][-1] += chunk_size - else: - requests_chunk[idx].append(chunk_size) - remain_batched_tokens -= chunk_size - current_request_size[idx] -= chunk_size - if current_request_size[idx] <= 0: - chunk_request_num -= 1 - - if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0: - return - - current_request_size = [request.prompt_token_ids_len for request in requests] - requests_chunk = [[] for _ in range(len(requests))] - chunk_request_num = len(current_request_size) - while chunk_request_num >= 1: - remain_batched_tokens = self.cfg.max_num_batched_tokens - for idx in range(len(current_request_size)): - if current_request_size[idx] <= 0: - continue - chunk_size = min( - current_request_size[idx], - self.partial_chunked_tokens[chunk_request_num], - ) - update_tokens(idx, chunk_size) - - while remain_batched_tokens >= self.cfg.cache_config.block_size: - # 当前 max_num_batched_tokens 还有剩余时,优先分配给较短的请求 - waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0] - if len(waiting_requests) == 0: - break - - available_tokens = ( - remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size - ) - append_idx = current_request_size.index(min(waiting_requests)) - chunk_size = min( - current_request_size[append_idx], - self.partial_chunked_tokens[chunk_request_num], - available_tokens, - ) - update_tokens(append_idx, chunk_size, update_chunk=True) - - for idx in range(len(requests)): - requests[idx].set("prefill_chunk_info", requests_chunk[idx]) - - def update_mm_requests_chunk_size(self, requests): - """ - update each multimodal request's chunk size info - """ - if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0: - return - - for request in requests: - inputs = 
request.multimodal_inputs - # 兼容没有图片和视频的情况 - if inputs["images"] is None: - inputs["image_type_ids"] = np.array([], dtype="int32") - inputs["grid_thw"] = np.array([], dtype="int64") - inputs["images"] = np.array([], dtype="uint8") - input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64") - image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32") - image_mask = input_ids == self.data_processor.image_patch_id - image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32") - image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32")) - grid_thw = [] - for one in inputs["grid_thw"]: - if one[0] == 1: - grid_thw.append(one) - else: - grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2)) - grid_thw = paddle.to_tensor(grid_thw, dtype="int64") - - from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse - - chunk_image_num, chunk_seq_len = get_mm_split_fuse( - input_ids, - image_type_ids, - image_token_sum, - grid_thw, - self.data_processor.image_patch_id, - len(grid_thw), - 0, - len(input_ids), - 0, - self.partial_chunked_tokens[1], - 2048, - ) - - grid_thw = grid_thw.numpy().reshape([-1, 3]) - num_chunks = len(chunk_image_num) - chunks_info = [] - input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0 - for idx in range(num_chunks): - chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] - chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] - actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0]) - chunk_image_type_ids = inputs["image_type_ids"][ - image_type_ids_st : image_type_ids_st + actual_image_num - ] - chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]] - chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1)) - chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num] - - chunks_info.append( - { - "input_ids": chunk_input_ids, - "token_type_ids": chunk_token_type_ids, - "image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None), - "grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None), - "images": (chunk_images if chunk_images.shape[0] else None), - "position_ids": None, - } - ) - - input_ids_st += chunk_seq_len[idx] - image_type_ids_st += actual_image_num - grid_thw_st += chunk_image_num[idx] - patch_st += chunk_patch_num - request.set("prefill_chunk_info", chunks_info) - - def insert_tasks(self, tasks, current_id=-1, allocated=False): - """ - Insert tasks to engine. - """ - # TODO 返回至 scheduler - if allocated: - current_tasks = [] - for task in tasks: - cur_task_idx = self.resource_manager.req_dict[task.request_id] - del self.resource_manager.req_dict[task.request_id] - cur_task = self.resource_manager.tasks_list[cur_task_idx] - cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] - if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode": - cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids) - if task.error_code != 200: - self.resource_manager.stop_flags[cur_task_idx] = True - self.resource_manager.tasks_list[cur_task_idx] = None - self.resource_manager._recycle_block_tables(cur_task) - if task.request_id in self.token_processor.tokens_counter: - del self.token_processor.tokens_counter[task.request_id] - self.scheduler.put_results([task]) - llm_logger.warning( - f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." 
- ) - continue - self.token_processor.tokens_counter[task.request_id] = 1 - current_tasks.append(cur_task) - self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) - return True - - for task in tasks: - start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) - - self.resource_manager.check_and_free_block_tables() - - if not isinstance(tasks, list): - tasks = [tasks] - - for item in tasks: - item.schedule_start_time = time.time() - - available_batch = np.sum(self.resource_manager.stop_flags) - if len(tasks) > available_batch: - llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") - llm_logger.error("The exceeded part will be ignored!") - tasks = tasks[:available_batch] - - req_ids = [t.request_id for t in tasks] - - tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) - - if not tasks: - error_msg = f"The request required resources is exceed the limit, request id={req_ids}." - llm_logger.error(error_msg) - raise EngineError(error_msg, error_code=500) - return False - - self.token_processor.number_of_tasks += len(tasks) - - is_decode = False - is_prefill = False - for i in range(len(tasks)): - if tasks[i].disaggregate_info is not None: - if tasks[i].disaggregate_info["role"] == "decode": - is_decode = True - else: - is_prefill = True - self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len - - self.split_connector.send_cache_infos(tasks, current_id) - if not is_decode: - llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") - for task in tasks: - task.inference_start_time = time.time() - if not is_prefill: - if not self.cfg.model_config.enable_mm: - self.update_requests_chunk_size(tasks) - else: - self.update_mm_requests_chunk_size(tasks) - self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) - if is_prefill and self.cfg.scheduler_config.name != "splitwise": - self.engine_worker_queue.available_prefill_instances.put(1) - return True - - def task_is_finished(self, index): - """ - judge if the task is finished - """ - assert index < len(self.resource_manager.stop_flags) - return self.resource_manager.stop_flags[index] - - def all_tasks_finished(self): - """ - judge if all tasks are finished - """ - return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags) - - def _set_warmup_token_processor(self): - """ - set token_processor for warmup - """ - self.token_processor_backup = self.token_processor - self.token_processor = WarmUpTokenProcessor(self.cfg) - self.token_processor.set_resource_manager(self.resource_manager) - self.token_processor.tasks_queue = self.engine_worker_queue - - # start TokenProcessor thread - self.token_processor.run() - - def _del_warmup_token_processor(self): - """ - delete token_processor for warmup - """ - self.token_processor.stop() - del self.token_processor - - # reset token_processor - self.token_processor = self.token_processor_backup - del self.token_processor_backup - def _worker_processes_ready(self): """ judge if all worker processes are ready @@ -862,7 +270,7 @@ class LLMEngine: """ Initialize shared memory to indicate engine status """ - # worker_ready_signatensor_parallel_size + # worker_ready_signal 用于worker进程感知engine是否启动完成 worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32) self.worker_ready_signal = IPCSignal( name="worker_ready_signal", @@ -872,37 +280,7 @@ class LLMEngine: create=True, ) - # exist_task_signal: Used by each worker process to 
detect whether there is a new task to be processed - exist_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32) - self.exist_task_signal = IPCSignal( - name="exist_task_signal", - array=exist_task_signal_data, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - - # exist_swapped_task_signal: Used by the engine to detect whether there is a swapped task in the worker - exist_swapped_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32) - self.exist_swapped_task_signal = IPCSignal( - name="exist_swapped_task_signal", - array=exist_swapped_task_signal_data, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - - # exist_prefill_task_signal: Used by each worker process to detect whether to prefill - exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32) - self.exist_prefill_task_signal = IPCSignal( - name="exist_prefill_task_signal", - array=exist_prefill_task_signal_data, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - - # launched_cache_manager_signal: Used to detect whether the engine has started cache_manager + # launched_cache_manager_signal 用于感知engine是否启动了cache_manager if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed": launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) self.launched_cache_manager_signal = IPCSignal( @@ -936,16 +314,6 @@ class LLMEngine: create=True, ) - # worker_live_signal: Used by the engine to detect whether each worker process is alive and record the time of each step - worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32) - self.worker_healthy_live_signal = IPCSignal( - name="worker_healthy_live_signal", - array=worker_healthy_live_recorded_time_array, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - if self.do_profile: if paddle.is_compiled_with_custom_device("iluvatar_gpu"): get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32) @@ -959,15 +327,6 @@ class LLMEngine: create=True, ) - model_weights_status = np.zeros([1], dtype=np.int32) - self.model_weights_status_signal = IPCSignal( - name="model_weights_status", - array=model_weights_status, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - def _exit_sub_services(self): """ exit sub services @@ -975,8 +334,8 @@ class LLMEngine: self.running = False if hasattr(self, "cache_manager_processes"): - self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() - self.resource_manager.cache_manager.cache_ready_signal.clear() + self.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() + self.engine.resource_manager.cache_manager.cache_ready_signal.clear() for p in self.cache_manager_processes: llm_logger.info(f"Killing cache manager process {p.pid}") try: @@ -986,20 +345,16 @@ class LLMEngine: f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}" ) self.worker_ready_signal.clear() - self.exist_task_signal.clear() - self.exist_swapped_task_signal.clear() - self.worker_healthy_live_signal.clear() - self.exist_prefill_task_signal.clear() + self.loaded_model_signal.clear() + if hasattr(self, "get_profile_block_num_signal"): self.get_profile_block_num_signal.clear() - self.model_weights_status_signal.clear() if hasattr(self, "worker_proc") and self.worker_proc is not None: try: os.killpg(self.worker_proc.pid, signal.SIGTERM) except Exception as e: 
console_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}") - self.engine_worker_queue.cleanup() if hasattr(self, "zmq_server") and self.zmq_server is not None: self.zmq_server.close() if hasattr(self, "dp_processed"): @@ -1077,6 +432,10 @@ class LLMEngine: else len(self.data_processor.tokenizer.vocab) ) + ports = ",".join(self.cfg.engine_worker_queue_port) + ips = None + if self.cfg.ips is not None: + ips = ",".join(self.cfg.ips) arguments = ( f" --devices {self.cfg.device_ids} {py_script}" f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}" @@ -1084,14 +443,14 @@ class LLMEngine: f" --model {self.cfg.model_config.model!s}" f" --device_ids {self.cfg.device_ids}" f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}" - f" --engine_worker_queue_port {self.cfg.engine_worker_queue_port!s}" + f" --engine_worker_queue_port {ports}" f" --pod_ip {self.cfg.master_ip}" f" --total_block_num {self.cfg.cache_config.total_block_num}" f" --block_size {self.cfg.cache_config.block_size}" f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}" f" --eos_tokens_lens {self.data_processor.eos_token_id_len}" f" --pad_token_id {self.data_processor.pad_token_id}" - f" --engine_pid {self.engine_pid}" + f" --engine_pid {self.cfg.engine_worker_queue_port[0]}" f" --max_num_batched_tokens {self.cfg.max_num_batched_tokens}" f" --splitwise_role {self.cfg.splitwise_role}" f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}" @@ -1105,7 +464,7 @@ class LLMEngine: f" --load_strategy {self.cfg.load_config.load_strategy}" f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" f" --load_choices {self.cfg.load_config.load_choices}" - f" --ips {self.cfg.ips}" + f" --ips {ips}" ) worker_append_flag = { @@ -1121,8 +480,9 @@ class LLMEngine: for worker_flag, value in worker_append_flag.items(): if value: arguments = arguments + f" --{worker_flag}" + llm_logger.info(f"gaoziyuan test ips :{self.cfg.ips}") if self.cfg.nnode > 1: - pd_cmd = pd_cmd + f" --ips {','.join(self.cfg.ips)} --nnodes {len(self.cfg.ips)}" + pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}" pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log" llm_logger.info(f"Launch worker service command: {pd_cmd}") p = subprocess.Popen( @@ -1197,7 +557,7 @@ class LLMEngine: output["outputs"]["reasoning_content"] = "" yield output - self.resource_manager.check_and_free_block_tables() + self.engine.check_and_free_block_tables() def _stop_profile(self): """ @@ -1208,57 +568,32 @@ class LLMEngine: time.sleep(1) num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) - self.resource_manager.reset_cache_config(self.cfg.cache_config) + self.engine.resource_manager.reset_cache_config(self.cfg.cache_config) if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed": device_ids = self.cfg.device_ids.split(",") - self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager( - cache_config=self.cfg.cache_config, - tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, - device_ids=device_ids, - pod_ip=self.cfg.master_ip, - engine_worker_queue_port=self.cfg.engine_worker_queue_port, - pid_suffix=self.ipc_signal_suffix, - ) + self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix) def check_health(self, time_interval_threashold=30): """ Check the health of the model server by checking whether all 
workers are alive. """ - if self.worker_healthy_live_signal.value[0]: - elapsed_time = time.time() - self.worker_healthy_live_signal.value[0] + if self.engine.worker_healthy_live_signal.value[0]: + elapsed_time = time.time() - self.engine.worker_healthy_live_signal.value[0] if elapsed_time > time_interval_threashold: return False, "Worker Service Not Healthy" return True, "" def launch_components(self): - self.token_processor.tasks_queue = self.engine_worker_queue - - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True) - else: - self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True) - self.insert_task_to_worker_thread.start() - - if self.api_server_pid is not None: - self.insert_task_to_scheduler_thread = threading.Thread( - target=self._insert_zmq_task_to_scheduler, daemon=True - ) - self.insert_task_to_scheduler_thread.start() - - self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True) - self.receive_output_thread.start() - - # Start TokenProcessor thread - self.token_processor.run() - if self.cfg.splitwise_role != "mixed": # 单机逻辑 - self.engine_worker_queue.available_prefill_instances.put(1) - self.split_mode_get_tasks() + self.engine.engine_worker_queue.available_prefill_instances.put(1) + self.engine.split_mode_get_tasks() if self.cfg.scheduler_config.name == "splitwise": - self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=()) + self.splitwise_receive_thread = threading.Thread( + target=self.engine.split_connector.start_receiver, args=() + ) self.splitwise_receive_thread.daemon = True self.splitwise_receive_thread.start() @@ -1268,35 +603,31 @@ class LLMEngine: host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info if self.cfg.scheduler_config.name == "splitwise": - self.scheduler.start(role, host_ip, disaggregate) + self.engine.scheduler.start(role, host_ip, disaggregate) - time.sleep(1) - expert_service_nums = self.cfg.parallel_config.data_parallel_size // self.cfg.nnode - if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: - self.dp_processed = [] - for i in range( - 1, - expert_service_nums, - ): - time.sleep(1) - self.dp_processed.append( - multiprocessing.Process( - target=start_expert_service, - args=( - self.cfg, - i + self.cfg.node_rank * self.cfg.worker_num_per_node, - self.ipc_signal_suffix, - ), + if not envs.FD_ENABLE_MULTI_API_SERVER: + if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: + self.dp_processed = [] + for i in range( + 1, + self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, + ): + self.dp_processed.append( + multiprocessing.Process( + target=start_data_parallel_service, + args=( + self.cfg, + i, + ), + ) ) - ) - llm_logger.info( - f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}" - + f" data parallel id {i}" - ) - self.dp_processed[-1].start() - for i in range(1, expert_service_nums): - while self.launched_expert_service_signal.value[i] == 0: - time.sleep(10) + llm_logger.info( + f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" + + f" data parallel id {i}" + ) + self.dp_processed[-1].start() + while self.launched_expert_service_signal.value[i] == 0: + time.sleep(1) def check_worker_initialize_status(self): """ @@ -1356,42 +687,3 @@ class 
LLMEngine: except Exception: pass return True - - def start_queue_service(self): - """ - start queue service for engine worker communication - """ - address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port) - if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0": - llm_logger.info(f"Starting engine worker queue server service at {address}") - self.engine_worker_queue_server = EngineWorkerQueue( - address=address, - is_server=True, - num_client=self.cfg.parallel_config.tensor_parallel_size, - local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, - ) - - if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed": - self.cache_task_queue = EngineCacheQueue( - address=( - self.cfg.master_ip, - self.cfg.cache_config.cache_queue_port, - ), - authkey=b"cache_queue_service", - is_server=True, - num_client=self.cfg.parallel_config.tensor_parallel_size, - client_id=-1, - local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, - ) - - self.engine_worker_queue = EngineWorkerQueue( - address=address, - is_server=False, - num_client=self.cfg.parallel_config.tensor_parallel_size, - client_id=0, - local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, - local_data_parallel_id=min( - self.cfg.worker_num_per_node * self.cfg.node_rank, - self.cfg.parallel_config.data_parallel_size - 1, - ), - ) diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 3b1e28c5d..3cbb68b0e 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -25,12 +25,9 @@ import weakref import numpy as np -from fastdeploy.engine.resource_manager import ResourceManager -from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal -from fastdeploy.metrics.metrics import main_process_metrics -from fastdeploy.output.token_processor import TokenProcessor -from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector -from fastdeploy.utils import EngineError, console_logger, llm_logger +from fastdeploy.engine.common_engine import EngineSevice +from fastdeploy.inter_communicator import IPCSignal +from fastdeploy.utils import console_logger, envs, llm_logger class ExpertService: @@ -49,36 +46,16 @@ class ExpertService: Args: cfg (Config): Config object containing all the configuration parameters. 
""" + self.cfg = cfg start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size if cfg.splitwise_role != "mixed": self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos] self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos] - self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id + llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}") self.cfg.disaggregate_info = None - self.scheduler = cfg.scheduler_config.scheduler() - if cfg.splitwise_role != "mixed": - self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") - - self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id - - address = (cfg.master_ip, cfg.engine_worker_queue_port) - self.engine_worker_queue = EngineWorkerQueue( - address=address, - is_server=False, - client_id=0, - num_client=cfg.parallel_config.tensor_parallel_size, - local_data_parallel_id=local_data_parallel_id, - ) - self.resource_manager = ResourceManager( - cfg.max_num_seqs, - cfg, - cfg.parallel_config.tensor_parallel_size, - cfg.splitwise_role, - local_data_parallel_id, - ) if cfg.splitwise_role != "mixed": if len(self.cfg.cache_config.pd_comm_port) == 1: self.cfg.cache_config.pd_comm_port[0] = ( @@ -86,29 +63,11 @@ class ExpertService: ) else: self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]] + self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id - self.split_connector = SplitwiseConnector( - self.cfg, - self.scheduler, - self.engine_worker_queue, - self.resource_manager, - ) - - self.token_processor = TokenProcessor( - cfg=cfg, - cached_generated_tokens=self.scheduler, - engine_worker_queue=self.engine_worker_queue, - split_connector=self.split_connector, - ) - self.token_processor.set_resource_manager(self.resource_manager) - - self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1) - for idx in range(1, self.cfg.max_num_partial_prefills + 1): - self.partial_chunked_tokens[idx] = ( - (self.cfg.max_num_batched_tokens // idx) - // self.cfg.cache_config.block_size - * self.cfg.cache_config.block_size - ) + self.engine = EngineSevice(self.cfg) + if self.cfg.scheduler_config.name == "splitwise": + self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self._finalizer = weakref.finalize(self, self._exit_sub_services) @@ -119,245 +78,62 @@ class ExpertService: to keep getting request from zmq_server. """ # assert not self.is_started, "The engine is already started." 
+ start_time = time.time() + self.engine.start() + if ipc_signal_suffix is not None: + self.api_server_pid = ipc_signal_suffix + self.engine.start_zmq_service(ipc_signal_suffix) + else: + ipc_signal_suffix = self.cfg.engine_worker_queue_port[0] llm_logger.info(f"start expert service {local_data_parallel_id}") if self.cfg.splitwise_role != "mixed": - self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager( - cache_config=self.cfg.cache_config, - tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, - device_ids=self.cfg.local_device_ids, - pod_ip=self.cfg.master_ip, - engine_worker_queue_port=self.cfg.engine_worker_queue_port, - pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}", + self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix) + self.engine.split_mode_get_tasks() + + if self.cfg.scheduler_config.name == "splitwise": + self.cfg.init_cache_info() + role = self.cfg.splitwise_role + host_ip = self.cfg.host_ip + disaggregate = self.cfg.disaggregate_info + self.engine.scheduler.start(role, host_ip, disaggregate) + + if self.cfg.splitwise_role != "mixed": + self.splitwise_receive_thread = threading.Thread( + target=self.engine.split_connector.start_receiver, args=() ) - self.split_mode_get_tasks() - - self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=()) - self.insert_task_to_worker_thread.daemon = True - self.insert_task_to_worker_thread.start() - - # Start TokenProcessor thread - os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port)) - - self.token_processor.run() - self.cfg.init_cache_info() - role = self.cfg.splitwise_role - host_ip = self.cfg.host_ip - disaggregate = self.cfg.disaggregate_info - self.scheduler.start(role, host_ip, disaggregate) + self.splitwise_receive_thread.daemon = True + self.splitwise_receive_thread.start() self.cfg.print() - - launched_expert_service_signal_data = np.zeros( - shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32 - ) - self.launched_expert_service_signal = IPCSignal( - name="launched_expert_service_signal", - array=launched_expert_service_signal_data, - dtype=np.int32, - suffix=ipc_signal_suffix, - create=False, - ) local_rank = local_data_parallel_id % self.cfg.worker_num_per_node - self.launched_expert_service_signal.value[local_rank] = 1 + + if not envs.FD_ENABLE_MULTI_API_SERVER: + launched_expert_service_signal_data = np.zeros( + shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32 + ) + self.launched_expert_service_signal = IPCSignal( + name="launched_expert_service_signal", + array=launched_expert_service_signal_data, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=False, + ) + self.launched_expert_service_signal.value[local_rank] = 1 console_logger.info( f"Worker processes(rank {local_rank}) are launched with {time.time() - start_time} seconds." ) return True - def _insert_task_to_worker(self): - """ - Insert task to engine thread, monitor scheduler request queue. 
- if the engine has resource, insert task to engine - """ - current_id = -1 - while True: - try: - if self.resource_manager.available_batch() == 0: - time.sleep(0.001) - continue - if self.engine_worker_queue.num_tasks() > 0: - time.sleep(0.001) - continue - if len(self.split_connector.current_request_ids) > 0: - time.sleep(0.001) - continue - - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - - self.resource_manager.check_and_free_block_tables() - tasks = self.scheduler.get_requests( - available_blocks=self.resource_manager.available_block_num(), - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, - max_num_batched_tokens=self.cfg.max_num_batched_tokens, - batch=num_prefill_batch, - ) - - if len(tasks) == 0: - time.sleep(0.001) - continue - - if self.cfg.splitwise_role != "mixed": - llm_logger.info("Inserting splitwise tasks") - self.split_connector.send_splitwise_tasks(tasks, current_id) - - current_id = (current_id + 1) % 100003 - - self.insert_tasks(tasks, current_id) - - main_process_metrics.num_requests_waiting.dec(len(tasks)) - main_process_metrics.num_requests_running.inc(len(tasks)) - except Exception as e: - err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}." - llm_logger.error(err_msg) - - def split_mode_get_tasks(self): - """ - Split mode get tasks - """ - waiting_requests = [] - - def receiver_loop(): - while True: - try: - if len(waiting_requests) > 0: - for task in waiting_requests: - if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): - self.insert_tasks([task]) - waiting_requests.remove(task) - else: - break - if not self.engine_worker_queue.disaggregate_queue_empty(): - items = self.engine_worker_queue.get_disaggregated_tasks() - for item in items: - role = item[0] - tasks = item[1] - if role == "prefill": - llm_logger.info("get prefill tasks") - for task in tasks: - task.max_tokens = task.min_tokens = 2 - self.insert_tasks(tasks) - elif role == "decode": - llm_logger.info(f"get decode tasks {tasks}") - if hasattr(tasks[0], "finished"): - if not isinstance(tasks, list): - tasks = [tasks] - for task in tasks: - task.finished = False - # self.scheduler.put_results(tasks) - - self.insert_tasks(tasks, allocated=True) - else: - if len(waiting_requests): - for task in tasks: - waiting_requests.append(task) - else: - for task in tasks: - if not self.resource_manager.is_resource_sufficient( - task.prompt_token_ids_len - ): - waiting_requests.append(task) - else: - self.insert_tasks([task]) - - else: - time.sleep(0.001) - continue - except Exception as e: - llm_logger.error(f"get decode tasks error: {e}, {str(traceback.format_exc())}") - - threading.Thread(target=receiver_loop, daemon=True).start() - - def insert_tasks(self, tasks, current_id=-1, allocated=False): - """ - Insert tasks to engine. 
- """ - if allocated: - current_tasks = [] - for task in tasks: - cur_task_idx = self.resource_manager.req_dict[task.request_id] - del self.resource_manager.req_dict[task.request_id] - cur_task = self.resource_manager.tasks_list[cur_task_idx] - if task.error_code != 200: - self.resource_manager.stop_flags[cur_task_idx] = True - self.resource_manager.tasks_list[cur_task_idx] = None - self.resource_manager._recycle_block_tables(cur_task) - if task.request_id in self.token_processor.tokens_counter: - del self.token_processor.tokens_counter[task.request_id] - self.scheduler.put_results([task]) - llm_logger.warning( - f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." - ) - continue - llm_logger.info(f"{cur_task_idx} {task.request_id}") - cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] - self.token_processor.tokens_counter[task.request_id] = 1 - current_tasks.append(cur_task) - self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) - return True - - self.resource_manager.check_and_free_block_tables() - - if not isinstance(tasks, list): - tasks = [tasks] - - for item in tasks: - item.schedule_start_time = time.time() - - available_batch = np.sum(self.resource_manager.stop_flags) - if len(tasks) > available_batch: - llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") - llm_logger.error("The exceeded part will be ignored!") - tasks = tasks[:available_batch] - - req_ids = [t.request_id for t in tasks] - - tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) - - if not tasks: - error_msg = f"The request required resources is exceed the limit, request id={req_ids}." - llm_logger.error(error_msg) - raise EngineError(error_msg, error_code=500) - return False - - self.token_processor.number_of_tasks += len(tasks) - - is_decode = False - is_prefill = False - for i in range(len(tasks)): - if tasks[i].disaggregate_info is not None: - if tasks[i].disaggregate_info["role"] == "decode": - is_decode = True - else: - is_prefill = True - self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len - if is_decode or is_prefill: - self.split_connector.send_cache_infos(tasks, current_id) - for task in tasks: - task.infer_start_time = time.time() - if not is_decode: - llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") - if not is_prefill and self.cfg.cache_config.enable_chunked_prefill: - if not self.cfg.model_config.enable_mm: - self.update_requests_chunk_size(tasks) - else: - self.update_mm_requests_chunk_size(tasks) - self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) - return True - def _exit_sub_services(self): """ exit sub services """ if hasattr(self, "cache_manager_processes"): - self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() - self.resource_manager.cache_manager.cache_ready_signal.clear() + self.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() + self.engine.resource_manager.cache_manager.cache_ready_signal.clear() for p in self.cache_manager_processes: llm_logger.info(f"Killing cache manager process {p.pid}") try: @@ -369,13 +145,16 @@ class ExpertService: self.zmq_server.close() -def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix): +def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=None): """ Start expert service """ expert_service = ExpertService(cfg, local_data_parallel_id) + try: expert_service.start(ipc_signal_suffix, 
local_data_parallel_id) - expert_service.split_connector.start_receiver() + while True: + time.sleep(1000) + except Exception as e: llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}") diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index c407a7663..4e7857c85 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -45,6 +45,7 @@ class EngineClient: max_model_len, tensor_parallel_size, pid, + port, limit_mm_per_prompt, mm_processor_kwargs, # enable_mm=False, @@ -75,13 +76,19 @@ class EngineClient: self.data_processor = input_processor.create_processor() self.max_model_len = max_model_len max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 - array_size = min(max_chips_per_node, tensor_parallel_size * data_parallel_size) + + if tensor_parallel_size < max_chips_per_node: + self.is_master = True + else: + self.is_master = False + + array_size = min(max_chips_per_node, tensor_parallel_size) self.worker_healthy_live_recorded_time_array = np.zeros(shape=[array_size], dtype=np.int32) self.worker_healthy_live_signal = IPCSignal( name="worker_healthy_live_signal", array=self.worker_healthy_live_recorded_time_array, dtype=np.int32, - suffix=pid, + suffix=port, create=False, ) self.semaphore = StatefulSemaphore((FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers) @@ -90,7 +97,7 @@ class EngineClient: name="model_weights_status", array=model_weights_status, dtype=np.int32, - suffix=pid, + suffix=port, create=False, ) self.connection_manager = DealerConnectionManager( diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 6abdcb768..7c9fbf01f 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -31,6 +31,7 @@ from prometheus_client import CONTENT_TYPE_LATEST from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.engine import LLMEngine +from fastdeploy.engine.expert_service import ExpertService from fastdeploy.entrypoints.chat_utils import load_chat_template from fastdeploy.entrypoints.engine_client import EngineClient from fastdeploy.entrypoints.openai.protocol import ( @@ -60,6 +61,7 @@ from fastdeploy.utils import ( FlexibleArgumentParser, StatefulSemaphore, api_server_logger, + configure_uvicorn_logging, console_logger, is_port_available, retrive_model_from_server, @@ -98,15 +100,10 @@ def load_engine(): api_server_logger.info(f"FastDeploy LLM API server starting... 
{os.getpid()}") engine_args = EngineArgs.from_cli_args(args) engine = LLMEngine.from_engine_args(engine_args) - if not engine.start(api_server_pid=os.getpid()): api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!") return None - api_server_logger.info("FastDeploy LLM engine initialized!\n") - console_logger.info(f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics") - console_logger.info(f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions") - console_logger.info(f"Launching completion service at http://{args.host}:{args.port}/v1/completions") llm_engine = engine return engine @@ -117,6 +114,25 @@ MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.w connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) +def load_data_service(): + """ + load data service + """ + global llm_engine + if llm_engine is not None: + return llm_engine + api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}") + engine_args = EngineArgs.from_cli_args(args) + config = engine_args.create_engine_config() + api_server_logger.info(f"local_data_parallel_id: {config.parallel_config}") + expert_service = ExpertService(config, config.parallel_config.local_data_parallel_id) + if not expert_service.start(os.getpid(), config.parallel_config.local_data_parallel_id): + api_server_logger.error("Failed to initialize FastDeploy LLM expert service, service exit now!") + return None + llm_engine = expert_service + return expert_service + + @asynccontextmanager async def lifespan(app: FastAPI): """ @@ -140,19 +156,20 @@ async def lifespan(app: FastAPI): model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)] engine_client = EngineClient( - args.model, - args.tokenizer, - args.max_model_len, - args.tensor_parallel_size, - pid, - args.limit_mm_per_prompt, - args.mm_processor_kwargs, + model_name_or_path=args.model, + tokenizer=args.tokenizer, + max_model_len=args.max_model_len, + tensor_parallel_size=args.tensor_parallel_size, + pid=pid, + port=int(args.engine_worker_queue_port[args.local_data_parallel_id]), + limit_mm_per_prompt=args.limit_mm_per_prompt, + mm_processor_kwargs=args.mm_processor_kwargs, # args.enable_mm, - args.reasoning_parser, - args.data_parallel_size, - args.enable_logprob, - args.workers, - args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + data_parallel_size=args.data_parallel_size, + enable_logprob=args.enable_logprob, + workers=args.workers, + tool_parser=args.tool_call_parser, ) app.state.dynamic_load_weight = args.dynamic_load_weight model_handler = OpenAIServingModels( @@ -176,6 +193,9 @@ async def lifespan(app: FastAPI): app.state.engine_client = engine_client app.state.chat_handler = chat_handler app.state.completion_handler = completion_handler + global llm_engine + if llm_engine is not None: + llm_engine.engine.data_processor = engine_client.data_processor yield # close zmq try: @@ -510,8 +530,18 @@ def launch_controller_server(): def main(): """main函数""" - if load_engine() is None: - return + configure_uvicorn_logging() + load_model_register_plugins() + if args.local_data_parallel_id == 0: + if not load_engine(): + return + else: + if not load_data_service(): + return + api_server_logger.info("FastDeploy LLM engine initialized!\n") + console_logger.info(f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics") + console_logger.info(f"Launching chat completion 
service at http://{args.host}:{args.port}/v1/chat/completions") + console_logger.info(f"Launching completion service at http://{args.host}:{args.port}/v1/completions") launch_controller_server() launch_metrics_server() diff --git a/fastdeploy/entrypoints/openai/multi_api_server.py b/fastdeploy/entrypoints/openai/multi_api_server.py new file mode 100644 index 000000000..358f0f8f0 --- /dev/null +++ b/fastdeploy/entrypoints/openai/multi_api_server.py @@ -0,0 +1,107 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import argparse +import os +import subprocess +import sys +import time + +from fastdeploy.utils import get_logger, is_port_available + +logger = get_logger("multi_api_server", "multi_api_server.log") + + +def start_servers(server_count, server_args, ports, metrics_ports): + processes = [] + logger.info(f"Starting servers on ports: {ports} with args: {server_args} and metrics ports: {metrics_ports}") + for i in range(len(server_args)): + if server_args[i] == "--engine-worker-queue-port": + engine_worker_queue_port = server_args[i + 1].split(",") + break + check_param(ports, server_count) + check_param(metrics_ports, server_count) + check_param(engine_worker_queue_port, server_count) + # check_param(server_args, server_count) + for i in range(server_count): + port = int(ports[i]) + metrics_port = int(metrics_ports[i]) + + env = os.environ.copy() + env["FD_LOG_DIR"] = f"log_{i}" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + *server_args, + "--port", + str(port), + "--metrics-port", + str(metrics_port), + "--local-data-parallel-id", + str(i), + ] + + # 启动子进程 + proc = subprocess.Popen(cmd, env=env) + processes.append(proc) + logger.info(f"Starting servers #{i+1} (PID: {proc.pid}) port: {port} | command: {' '.join(cmd)}") + + return processes + + +def check_param(ports, num_servers): + logger.info(f"check param {ports}, {num_servers}") + assert len(ports) == num_servers, "Number of ports must match num-servers" + for port in ports: + logger.info(f"check port {port}") + if not is_port_available("0.0.0.0", int(port)): + raise ValueError(f"Port {port} is already in use.") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--ports", default="8000,8002", type=str, help="ports to the http server") + parser.add_argument("--num-servers", default=2, type=int, help="number of workers") + parser.add_argument("--metrics-ports", default="8800,8802", type=str, help="ports for metrics server") + parser.add_argument("--args", nargs=argparse.REMAINDER, help="remaining arguments are passed to api_server.py") + args = parser.parse_args() + + logger.info(f"Starting {args.num_servers} servers on ports: {args.ports} with args: {args.args}") + # check_param(args.ports, args.num_servers) + # check_param(args.metrics_ports, args.num_servers) + # check_param(args.args.engine_worker_queue_port, args.num_servers) + + processes = start_servers( + server_count=args.num_servers, + 
server_args=args.args, + ports=args.ports.split(","), + metrics_ports=args.metrics_ports.split(","), + ) + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + for proc in processes: + proc.terminate() + for proc in processes: + proc.wait() + logger.info("All servers stopped.") + + +if __name__ == "__main__": + main() diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index c65f8c38d..c157bd0e7 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -37,7 +37,7 @@ from fastdeploy.entrypoints.openai.protocol import ( UsageInfo, ) from fastdeploy.metrics.work_metrics import work_process_metrics -from fastdeploy.utils import api_server_logger, get_host_ip +from fastdeploy.utils import api_server_logger from fastdeploy.worker.output import LogprobsLists @@ -50,15 +50,16 @@ class OpenAIServingChat: self.engine_client = engine_client self.models = models self.pid = pid - self.master_ip = ips self.max_waiting_time = max_waiting_time - self.host_ip = get_host_ip() self.chat_template = chat_template - if self.master_ip is not None: - if isinstance(self.master_ip, list): - self.master_ip = self.master_ip[0] + if ips is not None: + if isinstance(ips, list): + self.master_ip = ips[0] else: - self.master_ip = self.master_ip.split(",")[0] + self.master_ip = ips.split(",")[0] + else: + self.master_ip = "0.0.0.0" + api_server_logger.info(f"master ip: {self.master_ip}") async def _ensure_connection_manager(self): """ensure connection manager initialized""" @@ -67,19 +68,16 @@ class OpenAIServingChat: self.engine_client.connection_initialized = True def _check_master(self): - if self.master_ip is None: - return True - if self.host_ip == self.master_ip: - return True - return False + return self.engine_client.is_master async def create_chat_completion(self, request: ChatCompletionRequest): """ Create a new chat completion using the specified parameters. 
""" - if not self._check_master(): - err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}" + err_msg = ( + f"Only master node can accept completion request, please send request to master node: {self.master_ip}" + ) api_server_logger.error(err_msg) return ErrorResponse(message=err_msg, code=400) @@ -117,7 +115,6 @@ class OpenAIServingChat: api_server_logger.error(error_msg) self.engine_client.semaphore.release() return ErrorResponse(code=400, message=error_msg) - del current_req_dict if request.stream: @@ -193,6 +190,7 @@ class OpenAIServingChat: choices=[], model=model_name, ) + api_server_logger.info(f"create chat completion request: {request_id}") try: await self._ensure_connection_manager() @@ -388,7 +386,6 @@ class OpenAIServingChat: enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None include_stop_str_in_output = request.include_stop_str_in_output - try: await self._ensure_connection_manager() dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id) diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 858eed735..a6f782e17 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -33,7 +33,7 @@ from fastdeploy.entrypoints.openai.protocol import ( ErrorResponse, UsageInfo, ) -from fastdeploy.utils import api_server_logger, get_host_ip +from fastdeploy.utils import api_server_logger from fastdeploy.worker.output import LogprobsLists @@ -42,14 +42,14 @@ class OpenAIServingCompletion: self.engine_client = engine_client self.models = models self.pid = pid - self.master_ip = ips - self.host_ip = get_host_ip() self.max_waiting_time = max_waiting_time - if self.master_ip is not None: - if isinstance(self.master_ip, list): - self.master_ip = self.master_ip[0] + if ips is not None: + if isinstance(ips, list): + self.master_ip = ips[0] else: - self.master_ip = self.master_ip.split(",")[0] + self.master_ip = ips.split(",")[0] + else: + self.master_ip = "0.0.0.0" async def _ensure_connection_manager(self): """ensure connection manager initialized""" @@ -58,18 +58,16 @@ class OpenAIServingCompletion: self.engine_client.connection_initialized = True def _check_master(self): - if self.master_ip is None: - return True - if self.host_ip == self.master_ip: - return True - return False + return self.engine_client.is_master async def create_completion(self, request: CompletionRequest): """ Create a completion for the given prompt. 
""" if not self._check_master(): - err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}" + err_msg = ( + f"Only master node can accept completion request, please send request to master node: {self.master_ip}" + ) api_server_logger.error(err_msg) return ErrorResponse(message=err_msg, code=400) if self.models: diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index d33eb01c2..08c414051 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -47,7 +47,7 @@ class DealerConnectionManager: self.running = True for index in range(self.max_connections): await self._add_connection(index) - api_server_logger.info(f"Started {self.max_connections} connections") + api_server_logger.info(f"Started {self.max_connections} connections, pid {self.pid}") async def _add_connection(self, index): """create a new connection and start listening task""" diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 790af9552..24b85ba91 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -86,6 +86,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), # support max connections "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")), + # enable multi api server + "FD_ENABLE_MULTI_API_SERVER": lambda: bool(int(os.getenv("FD_ENABLE_MULTI_API_SERVER", "0"))), "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))), } diff --git a/fastdeploy/inter_communicator/__init__.py b/fastdeploy/inter_communicator/__init__.py index 0c1cc0d9f..41eb1ccc2 100644 --- a/fastdeploy/inter_communicator/__init__.py +++ b/fastdeploy/inter_communicator/__init__.py @@ -16,7 +16,7 @@ from .engine_cache_queue import EngineCacheQueue from .engine_worker_queue import EngineWorkerQueue -from .ipc_signal import IPCSignal +from .ipc_signal import IPCSignal, shared_memory_exists from .zmq_client import ZmqClient -__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"] +__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "shared_memory_exists"] diff --git a/fastdeploy/inter_communicator/ipc_signal.py b/fastdeploy/inter_communicator/ipc_signal.py index 0ac2e3fa0..075f1a461 100644 --- a/fastdeploy/inter_communicator/ipc_signal.py +++ b/fastdeploy/inter_communicator/ipc_signal.py @@ -18,6 +18,8 @@ from multiprocessing.shared_memory import SharedMemory import numpy as np +from fastdeploy.utils import llm_logger + def shared_memory_exists(name: str) -> bool: """Check if a shared memory block with the given name exists. 
@@ -35,7 +37,7 @@ def shared_memory_exists(name: str) -> bool: except FileNotFoundError: return False except Exception as e: - print(f"Unexpected error: {e}") + llm_logger.error(f"Unexpected error: {e}") return False @@ -78,7 +80,9 @@ class IPCSignal: name = name + f".{suffix}" if create: - assert not shared_memory_exists(name), f"ShareMemory: {name} already exists" + if shared_memory_exists(name): + llm_logger.warning(f"ShareMemory: {name} already exists, delete it") + SharedMemory(name=name, create=False).unlink() self.shm = SharedMemory(create=True, size=array.nbytes, name=name) self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf) self.value[:] = array # Initialize with input array data diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 6affcd8e7..9b259f40e 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -71,6 +71,7 @@ class ZmqClient: self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) self.router.setsockopt(zmq.SNDTIMEO, -1) self.router.bind(f"ipc://{self.router_path}") + llm_logger.info(f"router path: {self.router_path}") def send_json(self, data): """ @@ -126,7 +127,6 @@ class ZmqClient: continue else: break - if self.req_dict[req_id] == -1: if data[-1].finished: with self.mutex: diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 2403dcd7c..261aaf620 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -49,6 +49,7 @@ def get_moe_scores( compute moe scores using e_score_correction_bias. """ scores = paddle.nn.functional.sigmoid(gating_output) + assert e_score_correction_bias is not None, "e_score_correction_bias is none!" scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx = noaux_tc( scores, @@ -104,11 +105,12 @@ class DeepEPEngine: # In mixed EP mode on a single node, we dynamically switch between # high throughput and low latency modes. + if splitwise_role == "mixed": self.deepep_engine = deep_ep.Buffer( self.group, int(2e9), - int(5e9), + int(6e9), low_latency_mode=True, num_qps_per_rank=24, ) @@ -387,6 +389,7 @@ class EPPrefillRunner(EPRunner): *args, **kwargs, ): + ( num_tokens_per_rank, num_tokens_per_rdma_rank, diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index d60ab8ad8..31a7124ef 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -35,23 +35,22 @@ class SplitwiseConnector: SplitwiseConnector class for managing and scheduling Splitwise tasks. """ - def __init__(self, cfg, scheduler, worker_queue, resource_manager): + def __init__(self, cfg, worker_queue, resource_manager): """ Initialize the SplitwiseConnector instance. Parameters: cfg (dict): Configuration information. - scheduler (object): Scheduler object. worker_queue (object): Worker queue object. resource_manager (object): Resource manager object. 
""" self.cfg = cfg - self.scheduler = scheduler self.engine_worker_queue = worker_queue self.resource_manager = resource_manager self.connect_innode_instances = {} self.temp_cache_info = dict() self.current_request_ids = dict() + self.idx = self.cfg.parallel_config.local_data_parallel_id if self.cfg.cache_config.pd_comm_port is not None: self.zmq_ctx = zmq.Context() @@ -85,18 +84,20 @@ class SplitwiseConnector: """ while True: try: - socks = dict(self.poller.poll(100)) - if not socks: - continue + if hasattr(self, "poller"): + socks = dict(self.poller.poll(100)) + if not socks: + continue + else: + logger.debug(f"receive {socks}") + + frames = self.router_socket.recv_multipart() + logger.debug(f"frames: {frames}") + message = frames[-1] + self.io_executor.submit(self._process_message, message) + time.sleep(0.001) else: - logger.debug(f"receive {socks}") - - frames = self.router_socket.recv_multipart() - logger.debug(f"frames: {frames}") - message = frames[-1] - self.io_executor.submit(self._process_message, message) - time.sleep(0.001) - + time.sleep(5) except Exception as e: logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}") time.sleep(1) @@ -183,7 +184,7 @@ class SplitwiseConnector: def dispatch_innode_splitwise_tasks(self, tasks, current_id): """ - Dispatch splitwise tasks to the scheduler. + Dispatch splitwise tasks . Parameters: tasks (list): List of tasks. @@ -203,7 +204,7 @@ class SplitwiseConnector: "cache_info": { "ipc": { "ip": "0.0.0.0", - "port": self.cfg.engine_worker_queue_port, + "port": self.cfg.engine_worker_queue_port[self.idx], "current_id": current_id, }, }, @@ -286,7 +287,7 @@ class SplitwiseConnector: if port not in self.connect_innode_instances: self.create_connection(port) for task in tasks: - task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.engine_worker_queue_port + task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.engine_worker_queue_port[self.idx] self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks)) for task in tasks: task.disaggregate_info["cache_info"]["ipc"]["port"] = port diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index f60a96468..508bdb5e7 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -38,6 +38,7 @@ import yaml from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download from tqdm import tqdm from typing_extensions import TypeIs, assert_never +from uvicorn.config import LOGGING_CONFIG from fastdeploy import envs from fastdeploy.logger.logger import FastDeployLogger @@ -76,6 +77,35 @@ class ColoredFormatter(logging.Formatter): return message +def configure_uvicorn_logging(): + """ + uvicorn logger config + """ + # add timestamp to log + log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + date_format = "%Y-%m-%d %H:%M:%S" + LOGGING_CONFIG["formatters"]["default"]["fmt"] = log_format + LOGGING_CONFIG["formatters"]["default"]["datefmt"] = date_format + LOGGING_CONFIG["formatters"]["access"]["fmt"] = log_format + LOGGING_CONFIG["formatters"]["access"]["datefmt"] = date_format + + uvicorn_error_logger = logging.getLogger("") + uvicorn_access_logger = logging.getLogger("uvicorn.access") + for handler in uvicorn_error_logger.handlers[:]: + uvicorn_error_logger.removeHandler(handler) + for handler in uvicorn_access_logger.handlers[:]: + uvicorn_access_logger.removeHandler(handler) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(logging.Formatter(log_format, date_format)) + + 
uvicorn_error_logger.addHandler(console_handler) + uvicorn_access_logger.addHandler(console_handler) + uvicorn_error_logger.setLevel(logging.INFO) + uvicorn_access_logger.setLevel(logging.INFO) + uvicorn_error_logger.propagate = False + uvicorn_access_logger.propagate = False + + class DailyRotatingFileHandler(BaseRotatingHandler): """ like `logging.TimedRotatingFileHandler`, but this class support multi-process diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index d1f8f2c68..07341c23b 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -106,9 +106,7 @@ class GCUModelRunner(ModelRunnerBase): self.forward_meta: ForwardMeta = None # Postprocess Env params - os.environ["INFERENCE_MSG_QUEUE_ID"] = str( - self.local_rank + int(self.parallel_config.engine_worker_queue_port) - ) + os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port) def exist_prefill(self): """ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index b2568dfaa..cb4b8809c 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -153,9 +153,8 @@ class GPUModelRunner(ModelRunnerBase): self.forward_meta: ForwardMeta = None # Postprocess Env params - os.environ["INFERENCE_MSG_QUEUE_ID"] = str( - self.local_rank + int(self.parallel_config.engine_worker_queue_port) - ) + os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port) + logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}") def exist_prefill(self): """ diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 3f4e87302..5fa8e142e 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -152,19 +152,7 @@ class PaddleDisWorkerProc: # TODO(gongshaotian): Use worker factory to get worker self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks) - # Initialize task queue - task_address = ( - self.parallel_config.pod_ip, - self.parallel_config.engine_worker_queue_port, - ) self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 - self.task_queue = TaskQueue( - address=task_address, - is_server=False, - num_client=self.parallel_config.tensor_parallel_size, - client_id=self.parallel_config.tensor_parallel_rank, - local_data_parallel_id=self.parallel_config.data_parallel_rank, - ) def init_health_status(self) -> None: """ @@ -193,15 +181,16 @@ class PaddleDisWorkerProc: self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1 # init worker_healthy_live_signal - workers_alive = np.zeros(shape=[array_size], dtype=np.int32) + workers_alive = np.zeros(shape=[min(array_size, self.parallel_config.tensor_parallel_size)], dtype=np.int32) self.worker_healthy_live_signal = IPCSignal( name="worker_healthy_live_signal", array=workers_alive, dtype=np.int32, - suffix=self.parallel_config.engine_pid, + suffix=self.parallel_config.engine_worker_queue_port, create=False, ) - self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time()) + local_rank = self.local_rank % self.parallel_config.tensor_parallel_size + self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time()) # init model_weights_status workers_model_weights = np.zeros(shape=[1], dtype=np.int32) @@ -209,27 +198,27 @@ class PaddleDisWorkerProc: name="model_weights_status", 
array=workers_model_weights, dtype=np.int32, - suffix=self.parallel_config.engine_pid, + suffix=self.parallel_config.engine_worker_queue_port, create=False, ) # init exist_task_signal - workers_exist_task = np.zeros([self.parallel_config.data_parallel_size], dtype=np.int32) + workers_exist_task = np.zeros([1], dtype=np.int32) self.exist_task_signal = IPCSignal( name="exist_task_signal", array=workers_exist_task, dtype=np.int32, - suffix=self.parallel_config.engine_pid, + suffix=self.parallel_config.engine_worker_queue_port, create=False, ) # init exist_swapped_task_signal - workers_swapped_task = np.zeros(shape=[self.parallel_config.data_parallel_size], dtype=np.int32) + workers_swapped_task = np.zeros(shape=[1], dtype=np.int32) self.exist_swapped_task_signal = IPCSignal( name="exist_swapped_task_signal", array=workers_swapped_task, dtype=np.int32, - suffix=self.parallel_config.engine_pid, + suffix=self.parallel_config.engine_worker_queue_port, create=False, ) @@ -239,9 +228,10 @@ class PaddleDisWorkerProc: name="exist_prefill_task_signal", array=exist_prefill_task_signal_data, dtype=np.int32, - suffix=self.parallel_config.engine_pid, + suffix=self.parallel_config.engine_worker_queue_port, create=False, ) + logger.info("gaoziyuan test init_health_status") def event_loop_normal(self) -> None: """Main event loop for Paddle Distrubuted Workers. @@ -411,6 +401,21 @@ class PaddleDisWorkerProc: """Initialize device and Construct model runner""" self.worker.init_device() + def start_task_queue_service(self): + # Initialize task queue + task_address = ( + self.parallel_config.pod_ip, + self.parallel_config.engine_worker_queue_port, + ) + logger.info(f"connect task queue address {task_address}") + self.task_queue = TaskQueue( + address=task_address, + is_server=False, + num_client=self.parallel_config.tensor_parallel_size, + client_id=self.parallel_config.tensor_parallel_rank, + local_data_parallel_id=self.parallel_config.expert_parallel_rank, + ) + def load_model(self) -> None: """Load weights and create model""" @@ -444,7 +449,7 @@ def parse_args(): parser.add_argument("--total_block_num", type=int, default=2000) parser.add_argument("--block_size", type=int, default=64) parser.add_argument("--pod_ip", type=str, default="127.0.0.1") - parser.add_argument("--engine_worker_queue_port", type=int, default=9923) + parser.add_argument("--engine_worker_queue_port", type=str, default="9923") parser.add_argument("--max_model_len", type=int, default=3072, help="max model len") parser.add_argument("--device_ids", type=str, default="0", help="cuda visible devices") parser.add_argument("--dtype", type=str, default="bfloat16", help="input dtype") @@ -619,10 +624,16 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: num_experts_per_rank = num_experts // parallel_config.expert_parallel_size num_experts_start_offset = expert_parallel_rank * num_experts_per_rank + max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 + parallel_config.local_data_parallel_id = expert_parallel_rank % max_chips_per_node parallel_config.expert_parallel_rank = expert_parallel_rank parallel_config.num_experts_per_rank = num_experts_per_rank parallel_config.num_experts_start_offset = num_experts_start_offset + + parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[ + parallel_config.local_data_parallel_id + ] parallel_config.set_tp_group() load_config = LoadConfig(vars(args)) @@ -640,6 +651,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> 
FDConfig: logger.info(f"parallel_config.use_ep {parallel_config.use_ep}") logger.info(f"parallel_config.tensor_parallel_size {parallel_config.tensor_parallel_size}") logger.info(f"parallel_config.tensor_parallel_rank {parallel_config.tensor_parallel_rank}") + logger.info(f"parallel_config.engine_worker_queue_port {parallel_config.engine_worker_queue_port}") if getattr(model_config, "num_hidden_layers", None) is None: raise ValueError("num_hidden_layers is None") @@ -705,6 +717,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: graph_opt_config=graph_opt_config, early_stop_config=early_stop_config, cache_config=cache_config, + engine_worker_queue_port=args.engine_worker_queue_port, ips=args.ips, ) update_fd_config_for_mm(fd_config) @@ -746,6 +759,8 @@ def run_worker_proc() -> None: # Initialize health status worker_proc.init_health_status() + worker_proc.start_task_queue_service() + # Start event loop worker_proc.event_loop_normal() diff --git a/scripts/run_ci_xpu.sh b/scripts/run_ci_xpu.sh index 3b0c4252a..850db5aa5 100644 --- a/scripts/run_ci_xpu.sh +++ b/scripts/run_ci_xpu.sh @@ -92,6 +92,8 @@ if [ ${exit_code} -ne 0 ]; then exit 1 fi +sleep 5 + #0731新增kv block集中式管理相关测试,在起服务时启用对应环境变量 export ENABLE_V1_KVCACHE_SCHEDULER=True # 起服务 rm -rf log/*
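
The patch turns --engine_worker_queue_port into a comma-separated string (see the parse_args change in worker_process.py), and each data-parallel group then picks its own entry: initialize_fd_config narrows the parsed list with parallel_config.local_data_parallel_id, and SplitwiseConnector indexes it with the local data-parallel id as well. A minimal sketch of that selection logic, assuming nothing beyond the standard library; the helper name and the example port values are illustrative, not FastDeploy APIs:

# Sketch of the per-rank queue-port selection this patch introduces; the helper
# name and example values are illustrative only.
def select_worker_queue_port(engine_worker_queue_port, local_data_parallel_id):
    """Accept an int, a list, or a comma-separated string; return this rank's port."""
    if isinstance(engine_worker_queue_port, int):
        ports = [engine_worker_queue_port]
    elif isinstance(engine_worker_queue_port, str):
        ports = [int(p) for p in engine_worker_queue_port.split(",")]
    else:
        ports = [int(p) for p in engine_worker_queue_port]
    # one engine worker queue port per local data-parallel (expert-parallel) group
    return ports[local_data_parallel_id % len(ports)]


if __name__ == "__main__":
    print(select_worker_queue_port("8100,8101", 1))  # -> 8101
    print(select_worker_queue_port(9923, 0))         # -> 9923

The real code keeps the whole list on the config and reduces it to a single scalar inside initialize_fd_config, so downstream consumers such as TaskQueue and the INFERENCE_MSG_QUEUE_ID environment variable still see one port per worker process.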
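
The IPCSignal changes replace the earlier try/except create-then-attach pattern with an explicit shared_memory_exists probe, and the create path now warns and unlinks a stale segment instead of asserting. Below is a minimal, self-contained sketch of that probe-then-create idea using only multiprocessing.shared_memory; probe_or_create is an illustrative helper, not part of FastDeploy:

# Probe-then-create pattern for named shared memory, mirroring the intent of the
# shared_memory_exists() helper in this patch; probe_or_create is illustrative.
from multiprocessing.shared_memory import SharedMemory

import numpy as np


def shared_memory_exists(name):
    """Return True if a shared memory block with this name is already present."""
    try:
        shm = SharedMemory(name=name)
        shm.close()
        return True
    except FileNotFoundError:
        return False


def probe_or_create(name, array):
    """Attach to an existing block, or create one sized and initialized from array."""
    if shared_memory_exists(name):
        return SharedMemory(name=name)  # a second process simply attaches
    shm = SharedMemory(name=name, create=True, size=array.nbytes)
    view = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
    view[:] = array  # initialize only on first creation
    return shm


if __name__ == "__main__":
    shm = probe_or_create("demo_prefilled_step.0", np.zeros([1], dtype=np.int32))
    print(shared_memory_exists("demo_prefilled_step.0"))  # True
    shm.close()
    shm.unlink()  # clean up the demo segment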
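
check_health() and the worker_healthy_live_signal writes in worker_process.py treat the shared array as a per-rank heartbeat: every worker stamps int(time.time()) on each step, and the engine reports the service unhealthy once the recorded timestamp is older than the threshold (time_interval_threashold, default 30 seconds). A standalone illustration of that staleness rule, with the IPC-backed array replaced by a plain numpy array:

# Heartbeat staleness rule behind check_health(); the shared IPC array is
# replaced by a plain numpy array for this standalone example.
import time

import numpy as np

HEALTHY_WINDOW_S = 30  # mirrors the time_interval_threashold default


def record_heartbeat(live_signal, rank):
    """Worker side: stamp the current time for this rank on every step."""
    live_signal[rank] = int(time.time())


def check_health(live_signal):
    """Engine side: unhealthy once the recorded heartbeat is too old."""
    if live_signal[0]:
        elapsed = time.time() - live_signal[0]
        if elapsed > HEALTHY_WINDOW_S:
            return False, "Worker Service Not Healthy"
    return True, ""


if __name__ == "__main__":
    live = np.zeros([2], dtype=np.int32)
    record_heartbeat(live, 0)
    print(check_health(live))  # (True, '') right after a heartbeat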
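
The new multi_api_server entry point fans out one api_server process per data-parallel group, giving each instance its own HTTP port, metrics port, log directory (FD_LOG_DIR) and --local-data-parallel-id. A condensed sketch of that launcher loop, assuming two instances; the model path and port numbers are placeholders:

# Condensed sketch of the launcher loop added in multi_api_server.py; ports and
# the model path are example values only.
import os
import subprocess
import sys

ports = [8000, 8002]
metrics_ports = [8800, 8802]
server_args = ["--model", "./your_model", "--engine-worker-queue-port", "8100,8101"]

procs = []
for i, (port, metrics_port) in enumerate(zip(ports, metrics_ports)):
    env = os.environ.copy()
    env["FD_LOG_DIR"] = f"log_{i}"  # separate log directory per instance
    cmd = [
        sys.executable, "-m", "fastdeploy.entrypoints.openai.api_server",
        *server_args,
        "--port", str(port),
        "--metrics-port", str(metrics_port),
        "--local-data-parallel-id", str(i),  # instance i serves data-parallel rank i
    ]
    procs.append(subprocess.Popen(cmd, env=env))

for p in procs:
    p.wait()

The equivalent command line, using the flags the new module defines, would be along the lines of: python -m fastdeploy.entrypoints.openai.multi_api_server --num-servers 2 --ports 8000,8002 --metrics-ports 8800,8802 --args --model ./your_model --engine-worker-queue-port 8100,8101, where everything after --args is forwarded verbatim to each api_server instance.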