From e5804b1d985d992fa05f7950326e65e0e75e1734 Mon Sep 17 00:00:00 2001
From: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Date: Tue, 22 Jul 2025 21:23:48 +0800
Subject: [PATCH] Revert "[LLM] fix multinode bugs (#2945)" (#2971)

This reverts commit b0f1e0eef4a9ddf22ad12dec3472112fc5779eb6.
---
 fastdeploy/engine/args_utils.py               | 40 ++++++++++++++---
 fastdeploy/engine/config.py                   | 43 ++++++++-----------
 fastdeploy/engine/engine.py                   | 15 +++++--
 fastdeploy/entrypoints/engine_client.py       |  4 +-
 fastdeploy/entrypoints/openai/api_server.py   |  4 +-
 fastdeploy/entrypoints/openai/serving_chat.py | 10 ++---
 .../entrypoints/openai/serving_completion.py  |  8 +---
 fastdeploy/worker/gpu_worker.py               | 15 +++----
 fastdeploy/worker/worker_process.py           | 16 ++++---
 9 files changed, 87 insertions(+), 68 deletions(-)

diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index c18c7c65a..cdd9e81d9 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -124,9 +124,19 @@ class EngineArgs:
     Ratio of tokens to process in a block.
     """

-    ips: Optional[List[str]] = None
+    dist_init_ip: Optional[str] = None
     """
-    The ips of multinode deployment
+    The master node ip of multinode deployment
+    """
+
+    nnodes: int = 1
+    """
+    The number of nodes in multinode deployment
+    """
+
+    node_rank: int = 0
+    """
+    The rank of the current node in multinode deployment
     """

     swap_space: float = None
@@ -485,11 +495,25 @@
         # Cluster system parameters group
         system_group = parser.add_argument_group("System Configuration")
         system_group.add_argument(
-            "--ips",
-            type=lambda s: s.split(",") if s else None,
-            default=EngineArgs.ips,
+            "--dist-init-ip",
+            default=EngineArgs.dist_init_ip,
             help=
-            "IP addresses of all nodes participating in distributed inference.")
+            "IP addresses of master node.")
+
+        system_group.add_argument(
+            "--nnodes",
+            type=int,
+            default=EngineArgs.nnodes,
+            help=
+            "The number of all nodes.")
+
+        system_group.add_argument(
+            "--node-rank",
+            type=int,
+            default=EngineArgs.node_rank,
+            help=
+            "node rank id (range [0, nnodes)).")
+


         # Performance tuning parameters group
@@ -789,7 +813,9 @@
             max_num_seqs=self.max_num_seqs,
             speculative_config=speculative_cfg,
             max_num_batched_tokens=self.max_num_batched_tokens,
-            ips=self.ips,
+            dist_init_ip=self.dist_init_ip,
+            nnodes=self.nnodes,
+            node_rank=self.node_rank,
             use_warmup=self.use_warmup,
             engine_worker_queue_port=self.engine_worker_queue_port,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index 004fd30eb..02df10328 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -6,6 +6,7 @@
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
+#dist_init_ip
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -23,7 +24,7 @@ from fastdeploy import envs
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
 from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip,
-                              is_port_available, llm_logger)
+                              is_port_available, get_random_port, llm_logger)

 TaskOption = Literal["generate"]

@@ -641,7 +642,9 @@ class Config:
                  max_model_len: int = 8192,
                  max_num_seqs: int = 8,
                  max_num_batched_tokens: Optional[int] = None,
-                 ips: str = None,
+                 dist_init_ip: str = None,
+                 nnodes: int = 1,
+                 node_rank: int = 0,
                  speculative_config: Optional[Dict[str, Any]] = None,
                  graph_optimization_config: Optional[Dict[str, Any]] = None,
                  use_warmup: bool = False,
@@ -697,29 +700,15 @@
         self.tokenizer = tokenizer
         self.max_num_batched_tokens = max_num_batched_tokens
         self.tensor_parallel_size = tensor_parallel_size
-        self.ips = ips
+        self.dist_init_ip = dist_init_ip

-
-        if self.ips is None:
+        self.nnode = nnodes
+        self.node_rank = node_rank
+        if self.dist_init_ip is None:
             self.master_ip = "0.0.0.0"
-        elif isinstance(self.ips, list):
-            self.master_ip = self.ips[0]
         else:
-            self.ips = self.ips.split(",")
-            self.master_ip = self.ips[0]
-
-        if self.ips is None:
-            self.nnode = 1
-            self.node_rank = 0
-        else:
-            self.nnode = len(self.ips)
-
-            for idx, ip in enumerate(self.ips):
-                if ip == self.master_ip:
-                    self.node_rank = idx
-
-
-
+            self.master_ip = self.dist_init_ip
+            self.dist_init_addr = f"{self.dist_init_ip}:{get_random_port()}"

         self.max_model_len = max_model_len
         self.max_num_seqs = max_num_seqs
@@ -786,11 +775,14 @@
             assert self.device_ids.split(',').__len__() == self.worker_num_per_node, \
                 f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"

-        self.local_device_ids = self.device_ids.split(",")[: self.tensor_parallel_size]
+        assert self.worker_num_per_node % self.tensor_parallel_size == 0, \
+            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
+        self.local_device_ids = self.device_ids.split(
+            ',')[:self.tensor_parallel_size]

         self.host_ip = get_host_ip()

-        if self.ips is None or self.host_ip == self.master_ip:
+        if self.dist_init_ip is None or self.host_ip == self.master_ip:
             self.is_master = True
         else:
             self.is_master = False
@@ -829,6 +821,9 @@
         assert (
             is_port_available('0.0.0.0', self.engine_worker_queue_port)
         ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
+        assert (
+            self.max_chips_per_node >= self.tensor_parallel_size > 0
+        ), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}"
         assert (self.nnode >= 1), f"nnode: {self.nnode} should no less than 1"
         assert (
             self.max_model_len >= 16
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 5a4185a4b..6a5d30d21 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -879,7 +879,7 @@ class LLMEngine(object):
                 create=True)

         if self.do_profile:
-            get_profile_block_num = np.zeros([min(self.cfg.tensor_parallel_size, self.cfg.worker_num_per_node)], dtype=np.int32)
+            get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
             self.get_profile_block_num_signal = IPCSignal(
                 name="get_profile_block_num",
                 array=get_profile_block_num,
@@ -937,7 +937,10 @@
        配置环境变量
        """
        variables = {
-
+            "PADDLE_TRAINER_ID": 0,
+            "PADDLE_TRAINERS_NUM": 1,
+            "TRAINER_INSTANCES_NUM": 1,
+            "TRAINER_INSTANCES": "0.0.0.0",
             "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
             "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.device_ids.split(',')),
             "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
@@ -1053,7 +1056,11 @@
                     if value:
                         arguments = arguments + f" --{worker_flag}"
             if self.cfg.nnode > 1:
-                pd_cmd = pd_cmd + f" --ips {','.join(self.cfg.ips)} --nnodes {len(self.cfg.ips)}"
+                pd_cmd = pd_cmd + (
+                    f" --master {self.cfg.dist_init_addr}"
+                    f" --nnodes {str(self.cfg.nnode)}"
+                    f" --rank {str(self.cfg.node_rank)}"
+                )
             pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
             llm_logger.info("Launch worker service command: {}".format(pd_cmd))
             p = subprocess.Popen(
@@ -1137,7 +1144,7 @@
         """
         self.do_profile = 0
         num_gpu_blocks = -1
-        for i in range(min(self.cfg.tensor_parallel_size, self.cfg.worker_num_per_node)):
+        for i in range(self.cfg.tensor_parallel_size):
             while self.get_profile_block_num_signal.value[i] == 0:
                 time.sleep(1)
             if num_gpu_blocks < 0:
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index af9b4f45f..9ff35d47b 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -24,7 +24,6 @@ from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.engine.request import Request
 from fastdeploy.inter_communicator import ZmqClient, IPCSignal
 from fastdeploy.metrics.work_metrics import work_process_metrics
-from fastdeploy.platforms import current_platform
 from fastdeploy.utils import api_server_logger, EngineError


@@ -44,8 +43,7 @@
         self.reasoning_parser = reasoning_parser
         self.data_processor = input_processor.create_processor()
         self.max_model_len = max_model_len
-        max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
-        self.worker_healthy_live_recorded_time_array = np.zeros(shape=[tensor_parallel_size % max_chips_per_node], dtype=np.int32)
+        self.worker_healthy_live_recorded_time_array = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(name="worker_healthy_live_signal",
                                                     array=self.worker_healthy_live_recorded_time_array,
                                                     dtype=np.int32,
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 57835b472..e2ebf925d 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -122,8 +122,8 @@ async def lifespan(app: FastAPI):
                                  args.mm_processor_kwargs, args.enable_mm,
                                 args.reasoning_parser)
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
-    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.dist_init_ip)
+    completion_handler = OpenAIServingCompletion(engine_client, pid, args.dist_init_ip)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
     app.state.engine_client = engine_client
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 07cddd8da..778061b85 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -40,19 +40,15 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """

-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, dist_init_ip):
         self.engine_client = engine_client
         self.pid = pid
-        self.master_ip = ips
+        self.master_ip = dist_init_ip
         self.host_ip = get_host_ip()
-
+
     def _check_master(self):
         if self.master_ip is None:
             return True
-        if isinstance(self.master_ip, list):
-            self.master_ip = self.master_ip[0]
-        else:
-            self.master_ip = self.master_ip.split(",")[0]
         if self.host_ip == self.master_ip:
             return True
         return False
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index 1aeed2b09..acefc3d17 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -45,19 +45,15 @@ from fastdeploy.engine.request import RequestOutput


 class OpenAIServingCompletion:
-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, dist_init_ip):
         self.engine_client = engine_client
         self.pid = pid
-        self.master_ip = ips
+        self.master_ip = dist_init_ip
         self.host_ip = get_host_ip()

     def _check_master(self):
         if self.master_ip is None:
             return True
-        if isinstance(self.master_ip, list):
-            self.master_ip = self.master_ip[0]
-        else:
-            self.master_ip = self.master_ip.split(",")[0]
         if self.host_ip == self.master_ip:
             return True
         return False
diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py
index 66bb48a2a..18c1b4302 100644
--- a/fastdeploy/worker/gpu_worker.py
+++ b/fastdeploy/worker/gpu_worker.py
@@ -100,17 +100,16 @@ class GpuWorker(WorkerBase):
         # 1. Record memory state before profile run
         start_time = time.perf_counter()
         Gb = 1024**3
-        local_rank = self.local_rank % self.max_chips_per_node
-        paddle.device.cuda.reset_max_memory_reserved(local_rank)
-        paddle.device.cuda.reset_max_memory_allocated(local_rank)
+        paddle.device.cuda.reset_max_memory_reserved(self.local_rank)
+        paddle.device.cuda.reset_max_memory_allocated(self.local_rank)
         paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(
-            local_rank)
+            self.local_rank)
         paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(
-            local_rank)  # not reserved
+            self.local_rank)  # not reserved

         pynvml.nvmlInit()
         handle = pynvml.nvmlDeviceGetHandleByIndex(
-            int(self.device_ids[local_rank]))
+            int(self.device_ids[self.local_rank]))
         before_run_meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)

         logger.info((
@@ -127,9 +126,9 @@
         # 3. Statistical memory information
         paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(
-            local_rank)
+            self.local_rank)
         paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(
-            local_rank)
+            self.local_rank)

         model_block_memory_used = self.cal_theortical_kvcache()
         paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index cbb362647..89c5fdf2c 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -143,7 +143,7 @@ class PaddleDisWorkerProc():
         # Initialize task queue
         task_address = (self.parallel_config.pod_ip,
                         self.parallel_config.engine_worker_queue_port)
-        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+
         self.task_queue = TaskQueue(
             address=task_address,
             is_server=False,
@@ -162,6 +162,7 @@
             model_weights_status:
         """
         # init worker_ready_signal
+        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         array_size = min(
             self.max_chips_per_node,
             self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size)
@@ -182,9 +183,9 @@
             array=workers_alive,
             dtype=np.int32,
             suffix=self.parallel_config.engine_pid,
-            create=False,
-        )
-        self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
+            create=False)
+        self.worker_healthy_live_signal.value[self.local_rank % 8] = int(
+            time.time())

         # init model_weights_status
         workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
@@ -270,7 +271,8 @@
                 paddle.distributed.barrier()

             self.insert_step = False
-            self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
+            self.worker_healthy_live_signal.value[self.local_rank] = int(
+                time.time())

             # The first worker detects whether there are tasks in the task queue
             if self.local_rank % mp_num_per_node == 0:
@@ -386,7 +388,7 @@
                 suffix=self.parallel_config.engine_pid,
                 create=False)
             self.get_profile_block_num_signal.value[
-                self.local_rank % self.max_chips_per_node] = num_blocks_local
+                self.local_rank] = num_blocks_local

             # Wait all worker send the signal
             while np.any(self.get_profile_block_num_signal.value <= 0):
@@ -394,7 +396,7 @@
             num_blocks_global = self.get_profile_block_num_signal.value.min(
             ).item()
             self.get_profile_block_num_signal.value[
-                self.local_rank % self.max_chips_per_node] = num_blocks_global
+                self.local_rank] = num_blocks_global
         else:
             num_blocks_global = self.fd_config.parallel_config.total_block_num
         # NOTE(liuzichang): Too big num_blocks_global will lead to error 700
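
For orientation (not part of the patch itself): after this revert, a multinode deployment is described by the three explicit settings restored above (dist_init_ip, nnodes, node_rank) instead of a single `ips` list. The short Python sketch below mirrors the restored master-IP logic using a stand-in dataclass and made-up addresses; it illustrates the reverted behaviour and is not FastDeploy's real `Config` class.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class NodeSettings:
        """Stand-in for the node-level fields restored by this revert (illustrative only)."""
        dist_init_ip: Optional[str] = None  # master node IP; None means single-node
        nnodes: int = 1                     # total number of nodes in the deployment
        node_rank: int = 0                  # this node's rank, in the range [0, nnodes)


    def resolve_master_ip(settings: NodeSettings) -> str:
        # Mirrors Config.__init__ after the revert: the master IP comes straight from
        # dist_init_ip, with "0.0.0.0" as the single-node fallback.
        return "0.0.0.0" if settings.dist_init_ip is None else settings.dist_init_ip


    # Hypothetical two-node deployment; the same values map to the restored CLI flags
    # --dist-init-ip 192.168.1.10 --nnodes 2 --node-rank 0|1.
    print(resolve_master_ip(NodeSettings(dist_init_ip="192.168.1.10", nnodes=2, node_rank=1)))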