diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 0c52cbfc5..8439cabc9 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -210,13 +210,42 @@ class LLMEngine: engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=self.ipc_signal_suffix, ) - self.launched_cache_manager_signal.value[0] = 1 self.worker_proc = self._start_worker_service() console_logger.info("Waitting worker processes ready...") time.sleep(5) self.worker_init_status = dict() - if not self.check_worker_initialize_status(): + + result_container = {} + + def check_worker_initialize_status_func(res: dict): + res["worker_is_alive"] = True + if not self.check_worker_initialize_status(): + console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.") + res["worker_is_alive"] = False + + self.check_worker_initialize_status_func_thread = threading.Thread( + target=check_worker_initialize_status_func, args=(result_container,), daemon=True + ) + self.check_worker_initialize_status_func_thread.start() + + # Wait for model loading + while self.loaded_model_signal.value[0] == 0: + # Make sure worker process is alive + if not self.check_worker_initialize_status_func_thread.is_alive(): + return False + time.sleep(1) + + if self.do_profile: + self._stop_profile() + # Launch components: scheduler, cache_manager, expert_service et al. 
+ self.launch_components() + if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed": + self.launched_cache_manager_signal.value[0] = 1 + + # Worker launched + self.check_worker_initialize_status_func_thread.join() + if not result_container["worker_is_alive"]: console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.") return False @@ -228,68 +257,6 @@ class LLMEngine: self._del_warmup_token_processor() console_logger.info("Warmup finished") - self.token_processor.tasks_queue = self.engine_worker_queue - - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True) - else: - self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True) - self.insert_task_to_worker_thread.start() - - if self.api_server_pid is not None: - self.insert_task_to_scheduler_thread = threading.Thread( - target=self._insert_zmq_task_to_scheduler, daemon=True - ) - self.insert_task_to_scheduler_thread.start() - - self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True) - self.receive_output_thread.start() - - # Start TokenProcessor thread - self.token_processor.run() - - if self.cfg.splitwise_role != "mixed": - # 单机逻辑 - self.engine_worker_queue.available_prefill_instances.put(1) - self.split_mode_get_tasks() - if self.cfg.scheduler_config.name == "splitwise": - self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=()) - self.splitwise_receive_thread.daemon = True - self.splitwise_receive_thread.start() - - self.cfg.init_cache_info() - - role = self.cfg.splitwise_role - host_ip = self.cfg.host_ip - disaggregate = self.cfg.disaggregate_info - if self.cfg.scheduler_config.name == "splitwise": - self.scheduler.start(role, host_ip, disaggregate) - - time.sleep(1) - - if self.cfg.parallel_config.enable_expert_parallel and 
self.cfg.parallel_config.data_parallel_size > 1: - self.dp_processed = [] - for i in range( - 1, - self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, - ): - time.sleep(1) - self.dp_processed.append( - multiprocessing.Process( - target=start_expert_service, - args=( - self.cfg, - i + self.cfg.node_rank * self.cfg.worker_num_per_node, - self.ipc_signal_suffix, - ), - ) - ) - llm_logger.info( - f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" - + f" data parallel id {i}" - ) - self.dp_processed[-1].start() - console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True @@ -950,7 +917,6 @@ class LLMEngine: suffix=self.ipc_signal_suffix, create=True, ) - # launched_cache_manager_signal 用于感知engine是否启动了cache_manager if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed": launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) @@ -962,6 +928,29 @@ class LLMEngine: create=True, ) + + # launched_expert_service_signal: Used to sense whether each expert_service is started successfully + if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: + launched_expert_service_signal_data = np.zeros( + shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32 + ) + self.launched_expert_service_signal = IPCSignal( + name="launched_expert_service_signal", + array=launched_expert_service_signal_data, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=True, + ) + + # loaded_model_signal: Used to detect whether each worker has completed model loading + loaded_model_signal_data = np.zeros([1], dtype=np.int32) + self.loaded_model_signal = IPCSignal( + name="loaded_model_signal", + array=loaded_model_signal_data, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=True, + ) + # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间 worker_healthy_live_recorded_time_array = 
np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32) self.worker_healthy_live_signal = IPCSignal( @@ -1244,7 +1233,6 @@ class LLMEngine: engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=self.ipc_signal_suffix, ) - self.launched_cache_manager_signal.value[0] = 1 def check_health(self, time_interval_threashold=30): """ @@ -1258,6 +1246,72 @@ class LLMEngine: return True, "" + def launch_components(self): + self.token_processor.tasks_queue = self.engine_worker_queue + + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True) + else: + self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True) + self.insert_task_to_worker_thread.start() + + if self.api_server_pid is not None: + self.insert_task_to_scheduler_thread = threading.Thread( + target=self._insert_zmq_task_to_scheduler, daemon=True + ) + self.insert_task_to_scheduler_thread.start() + + self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True) + self.receive_output_thread.start() + + # Start TokenProcessor thread + self.token_processor.run() + + if self.cfg.splitwise_role != "mixed": + # 单机逻辑 + self.engine_worker_queue.available_prefill_instances.put(1) + self.split_mode_get_tasks() + if self.cfg.scheduler_config.name == "splitwise": + self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=()) + self.splitwise_receive_thread.daemon = True + self.splitwise_receive_thread.start() + + self.cfg.init_cache_info() + + role = self.cfg.splitwise_role + host_ip = self.cfg.host_ip + disaggregate = self.cfg.disaggregate_info + if self.cfg.scheduler_config.name == "splitwise": + self.scheduler.start(role, host_ip, disaggregate) + + time.sleep(1) + expert_service_nums = self.cfg.parallel_config.data_parallel_size // self.cfg.nnode + if 
self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: + self.dp_processed = [] + for i in range( + 1, + expert_service_nums, + ): + time.sleep(1) + self.dp_processed.append( + multiprocessing.Process( + target=start_expert_service, + args=( + self.cfg, + i + self.cfg.node_rank * self.cfg.worker_num_per_node, + self.ipc_signal_suffix, + ), + ) + ) + llm_logger.info( + f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" + + f" data parallel id {i}" + ) + self.dp_processed[-1].start() + for i in range(1, expert_service_nums): + while self.launched_expert_service_signal.value[i] == 0: + time.sleep(1) + def check_worker_initialize_status(self): """ Check the initlialize status of workers by stdout logging @@ -1283,10 +1337,6 @@ class LLMEngine: self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True) self.checking_worker_status_thread.start() - checking_worker_init_kv_cache_status_thread = None - if self.do_profile: - checking_worker_init_kv_cache_status_thread = threading.Thread(target=self._stop_profile, daemon=True) - checking_worker_init_kv_cache_status_thread.start() # display weight loadding progress with tqdm(total=100, desc="Loading Weights") as pbar: @@ -1317,8 +1367,6 @@ class LLMEngine: self.worker_init_status["finished"] = True try: self.checking_worker_status_thread.join(timeout=1) - if checking_worker_init_kv_cache_status_thread is not None: - checking_worker_init_kv_cache_status_thread.join(timeout=1) except Exception: pass return True diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 0032780b9..ce743c1ed 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -26,7 +26,7 @@ import weakref import numpy as np from fastdeploy.engine.resource_manager import ResourceManager -from fastdeploy.inter_communicator import EngineWorkerQueue +from fastdeploy.inter_communicator import 
EngineWorkerQueue, IPCSignal from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.output.token_processor import TokenProcessor from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector @@ -127,7 +127,7 @@ class ExpertService: cache_config=self.cfg.cache_config, tensor_parallel_size=self.cfg.tensor_parallel_size, device_ids=self.cfg.local_device_ids, - pod_ip=self.cfg.pod_ips[0], + pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}", ) @@ -150,7 +150,22 @@ class ExpertService: self.scheduler.start(role, host_ip, disaggregate) self.cfg.print() - console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") + launched_expert_service_signal_data = np.zeros( + shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32 + ) + self.launched_expert_service_signal = IPCSignal( + name="launched_expert_service_signal", + array=launched_expert_service_signal_data, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=False, + ) + local_rank = local_data_parallel_id % self.cfg.worker_num_per_node + self.launched_expert_service_signal.value[local_rank] = 1 + + console_logger.info( + f"Worker processes(rank {local_rank}) are launched with {time.time() - start_time} seconds." 
+ ) return True def _insert_task_to_worker(self): diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 924395565..828cdfd14 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -458,6 +458,17 @@ class PaddleDisWorkerProc: def load_model(self) -> None: """Load weights and create model""" self.worker.load_model() + loaded_model_signal_data = np.zeros(shape=[1], dtype=np.int32) + self.loaded_model_signal = IPCSignal( + name="loaded_model_signal", + array=loaded_model_signal_data, + dtype=np.int32, + suffix=self.parallel_config.engine_pid, + create=False, + ) + if self.ranks > 1: + paddle.distributed.barrier() + self.loaded_model_signal.value[0] = 1 def parse_args():