diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 07468cb84..6513e49fb 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -341,7 +341,8 @@ class ParallelConfig:
     def set_tp_group(self):
         # different tp group id
         # prevent different tp_groups using the same group_id
-        dist.collective._set_custom_gid(self.data_parallel_rank + 100)
+        tp_gid_offset = envs.FD_TP_GROUP_GID_OFFSET
+        dist.collective._set_custom_gid(self.data_parallel_rank + tp_gid_offset)
         self.tp_group = dist.new_group(
             range(
                 self.data_parallel_rank * self.tensor_parallel_size,
@@ -349,7 +350,8 @@ class ParallelConfig:
             )
         )
         # same ep group id
-        dist.collective._set_custom_gid(self.data_parallel_size + 100)
+        # TODO(gaoziyuan): move this gid config to ep.py
+        dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
         logger.info(
             f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
         )
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index ddcf31d28..c79658d34 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -484,7 +484,6 @@ class LLMEngine:
         for worker_flag, value in worker_append_flag.items():
             if value:
                 arguments = arguments + f" --{worker_flag}"
-        llm_logger.info(f"gaoziyuan test ips :{self.cfg.ips}")
         if self.cfg.nnode > 1:
             pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}"
         pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index cb8e93a8f..d726c0dca 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -88,6 +88,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
     # support max connections
     "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
+    # Offset for the tensor parallel group GID
+    "FD_TP_GROUP_GID_OFFSET": lambda: int(os.getenv("FD_TP_GROUP_GID_OFFSET", "1000")),
     # enable multi api server
     "FD_ENABLE_MULTI_API_SERVER": lambda: bool(int(os.getenv("FD_ENABLE_MULTI_API_SERVER", "0"))),
     "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
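Editorial note (not part of the patch): the hard-coded `+ 100` group-ID offset is now read from `FD_TP_GROUP_GID_OFFSET` (default 1000). A minimal sketch of the intended numbering, assuming one TP group per data-parallel replica plus one EP group; the helper functions below are illustrative only, not FastDeploy APIs.

```python
import os

def tp_group_gid(data_parallel_rank: int) -> int:
    # GID passed to _set_custom_gid before creating the TP group of one DP replica.
    offset = int(os.getenv("FD_TP_GROUP_GID_OFFSET", "1000"))
    return data_parallel_rank + offset

def ep_group_gid(data_parallel_size: int) -> int:
    # GID used for the EP group; it sits just past every TP GID (ranks are 0..size-1).
    offset = int(os.getenv("FD_TP_GROUP_GID_OFFSET", "1000"))
    return data_parallel_size + offset

if __name__ == "__main__":
    dp_size = 4
    tp_gids = [tp_group_gid(r) for r in range(dp_size)]  # [1000, 1001, 1002, 1003]
    assert ep_group_gid(dp_size) not in tp_gids          # 1004 never collides with a TP GID
    print(tp_gids, ep_group_gid(dp_size))
```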
+ "FD_TP_GROUP_GID_OFFSET": lambda: int(os.getenv("FD_TP_GROUP_GID_OFFSET", "1000")), # enable multi api server "FD_ENABLE_MULTI_API_SERVER": lambda: bool(int(os.getenv("FD_ENABLE_MULTI_API_SERVER", "0"))), "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))), diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 5b3b1c6a4..3fc37b845 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -170,10 +170,12 @@ class MoEMethodBase(QuantMethodBase): """ if layer.ep_size > 1: if layer.fd_config.parallel_config.moe_phase.phase == "prefill": - self.ep_prefill_runner.clean_low_latency_buffer() + if layer.fd_config.parallel_config.splitwise_role == "mixed": + self.ep_prefill_runner.clean_low_latency_buffer() return self.apply_ep_prefill(layer, x, gate) else: - self.ep_decoder_runner.clean_low_latency_buffer() + if layer.fd_config.parallel_config.splitwise_role == "mixed": + self.ep_decoder_runner.clean_low_latency_buffer() return self.apply_ep_decode(layer, x, gate) else: return self.apply_tp(layer, x, gate) diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index ce7bb8ac4..6e1097d86 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -325,28 +325,39 @@ def load_composite_checkpoint( # 2. Tensor Parallel (TP) # 3. Pre-sharded (pre-split) """ - rank_dirs = [ - f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f)) - ] - if len(rank_dirs) > 1: - if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs): - raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}") - state_dict = load_pre_sharded_checkpoint( - model_path, - fd_config.parallel_config.tensor_parallel_rank, - use_fastsafetensor=False, - ) + # (TODO: remove in the future) + if ( + fd_config.parallel_config.use_ep + and fd_config.speculative_config.model_type != "mtp" + and fd_config.parallel_config.tensor_parallel_size == 1 + ): + state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True) else: - if fd_config.load_config.use_fastsafetensor and (current_platform.available() and current_platform.is_cuda()): - state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True) - deal_state_dict(state_dict) - else: - state_dict = load_tp_checkpoint( + rank_dirs = [ + f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f)) + ] + if len(rank_dirs) > 1: + if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs): + raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}") + state_dict = load_pre_sharded_checkpoint( model_path, - cls, - fd_config.model_config.pretrained_config, - return_numpy=return_numpy, + fd_config.parallel_config.tensor_parallel_rank, + use_fastsafetensor=False, ) + else: + if fd_config.load_config.use_fastsafetensor and ( + current_platform.available() and current_platform.is_cuda() + ): + state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True) + deal_state_dict(state_dict) + else: + # NOTE: for very big model, cpu will be out of memory + state_dict = load_tp_checkpoint( + model_path, + cls, + fd_config.model_config.pretrained_config, + return_numpy=return_numpy, + 
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index ce7bb8ac4..6e1097d86 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -325,28 +325,39 @@ def load_composite_checkpoint(
     # 2. Tensor Parallel (TP)
     # 3. Pre-sharded (pre-split)
     """
-    rank_dirs = [
-        f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
-    ]
-    if len(rank_dirs) > 1:
-        if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
-            raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
-        state_dict = load_pre_sharded_checkpoint(
-            model_path,
-            fd_config.parallel_config.tensor_parallel_rank,
-            use_fastsafetensor=False,
-        )
+    # TODO: remove this special case in the future
+    if (
+        fd_config.parallel_config.use_ep
+        and fd_config.speculative_config.model_type != "mtp"
+        and fd_config.parallel_config.tensor_parallel_size == 1
+    ):
+        state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
     else:
-        if fd_config.load_config.use_fastsafetensor and (current_platform.available() and current_platform.is_cuda()):
-            state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
-            deal_state_dict(state_dict)
-        else:
-            state_dict = load_tp_checkpoint(
+        rank_dirs = [
+            f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
+        ]
+        if len(rank_dirs) > 1:
+            if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
+                raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
+            state_dict = load_pre_sharded_checkpoint(
                 model_path,
-                cls,
-                fd_config.model_config.pretrained_config,
-                return_numpy=return_numpy,
+                fd_config.parallel_config.tensor_parallel_rank,
+                use_fastsafetensor=False,
             )
+        else:
+            if fd_config.load_config.use_fastsafetensor and (
+                current_platform.available() and current_platform.is_cuda()
+            ):
+                state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
+                deal_state_dict(state_dict)
+            else:
+                # NOTE: loading very large models this way can run out of CPU memory
+                state_dict = load_tp_checkpoint(
+                    model_path,
+                    cls,
+                    fd_config.model_config.pretrained_config,
+                    return_numpy=return_numpy,
+                )
     if not state_dict:
         raise ValueError("weight not found in state_dict !")
     return state_dict
diff --git a/fastdeploy/model_executor/model_loader/default_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py
index 75c80bfa8..e1ee0ce1f 100644
--- a/fastdeploy/model_executor/model_loader/default_loader.py
+++ b/fastdeploy/model_executor/model_loader/default_loader.py
@@ -54,6 +54,7 @@ class DefaultModelLoader(BaseModelLoader):
     @measure_time
     def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
         model_class = ModelRegistry.get_pretrain_cls(architectures)
+
         state_dict = load_composite_checkpoint(
             fd_config.model_config.model,
             model_class,
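Editorial note (not part of the patch): `load_composite_checkpoint` now tries an expert-parallel fast path before the existing pre-sharded / fastsafetensor / plain-TP branches. A condensed sketch of the selection order, using a hypothetical pure function whose arguments stand in for the `fd_config` fields the patch reads:

```python
def select_loader(use_ep: bool, is_mtp: bool, tp_size: int,
                  rank_dir_count: int, use_fastsafetensor: bool, on_cuda: bool) -> str:
    if use_ep and not is_mtp and tp_size == 1:
        return "load_ep_checkpoint"          # new branch, checked first (marked TODO for later removal)
    if rank_dir_count > 1:                   # pre-split checkpoint: one rank* directory per TP rank
        if tp_size != rank_dir_count:
            raise ValueError(f"Your model only supports loading with tp{rank_dir_count}")
        return "load_pre_sharded_checkpoint"
    if use_fastsafetensor and on_cuda:
        return "load_tp_checkpoint_v1"
    return "load_tp_checkpoint"              # plain TP load; can exhaust CPU memory on huge models

assert select_loader(True, False, 1, 0, False, False) == "load_ep_checkpoint"
assert select_loader(False, False, 4, 4, False, False) == "load_pre_sharded_checkpoint"
assert select_loader(False, False, 2, 0, True, True) == "load_tp_checkpoint_v1"
```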
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index d7d0e8e40..08f8d2045 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -170,6 +170,7 @@ class PaddleDisWorkerProc:
             self.max_chips_per_node,
             self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size,
         )
+
         workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
         self.worker_ready_signal = IPCSignal(
             name="worker_ready_signal",
@@ -179,7 +180,6 @@ class PaddleDisWorkerProc:
             create=False,
         )
         self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1
-
         # init worker_healthy_live_signal
         workers_alive = np.zeros(shape=[min(array_size, self.parallel_config.tensor_parallel_size)], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
@@ -231,7 +231,6 @@ class PaddleDisWorkerProc:
             suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )
-        logger.info("gaoziyuan test init_health_status")

     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distrubuted Workers.
@@ -255,7 +254,8 @@ class PaddleDisWorkerProc:
            self.insert_step = False
            req_dicts = None

-            self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
+            local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+            self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

            # The first worker detects whether there are tasks in the task queue
            if self.local_rank % mp_num_per_node == 0:
@@ -267,7 +267,7 @@ class PaddleDisWorkerProc:
                if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
                    self.task_queue.read_finish_flag.set(1)
                else:
-                    self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] = 1
+                    self.exist_task_signal.value[0] = 1

            if self.parallel_config.tensor_parallel_size > 1:
                # Synchronize the signal for other workers
@@ -285,17 +285,14 @@ class PaddleDisWorkerProc:
                    self.parallel_config.engine_pid,
                )

-            if (
-                self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] == 1
-                or self.task_queue.read_finish_flag.get() == 1
-            ):
+            if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                logger.info(f"Rank: {self.local_rank} Detected new requests.")
                self.insert_step = True

                tasks, read_finish = self.task_queue.get_tasks()
                if read_finish:
                    # Ensure that every worker get the task
-                    self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] = 0
+                    self.exist_task_signal.value[0] = 0
                    self.task_queue.read_finish_flag.set(0)

                req_dicts = []
@@ -413,7 +410,7 @@ class PaddleDisWorkerProc:
            is_server=False,
            num_client=self.parallel_config.tensor_parallel_size,
            client_id=self.parallel_config.tensor_parallel_rank,
-            local_data_parallel_id=self.parallel_config.expert_parallel_rank,
+            local_data_parallel_id=self.parallel_config.data_parallel_rank,
        )

    def load_model(self) -> None:
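Editorial note (not part of the patch): in the worker event loop the heartbeat index is now folded into the TP group before being mapped onto the node, and `exist_task_signal` is always indexed at slot 0. A rough sketch of the new heartbeat slot computation; the helper and the monitoring comment are assumptions, not FastDeploy code:

```python
import time

def heartbeat_slot(local_rank: int, tensor_parallel_size: int, max_chips_per_node: int) -> int:
    # Patched behaviour: local_rank is reduced modulo the TP size first,
    # then mapped onto a per-node slot.
    local_rank = local_rank % tensor_parallel_size
    return local_rank % max_chips_per_node

def touch_heartbeat(signal_array: list, local_rank: int, tp_size: int, chips: int) -> None:
    # Each worker stamps the current time into its own slot; a stale timestamp
    # can then be interpreted as an unhealthy worker.
    signal_array[heartbeat_slot(local_rank, tp_size, chips)] = int(time.time())

slots = [0] * 8
touch_heartbeat(slots, local_rank=10, tp_size=8, chips=8)
assert slots[10 % 8] != 0  # rank 10 lands in slot 2
```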