[LLM] support multi node deploy (#2708)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* [LLM] support multi node deploy

* Update engine.py

* fix bugs

* fix

* [LLM] support multi node deploy

* [LLM] support multi node deploy

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
ltd0924
2025-07-06 10:33:51 +08:00
committed by GitHub
parent 04a8e1ef2b
commit 68b4755587
13 changed files with 157 additions and 87 deletions

View File

@@ -103,7 +103,7 @@ class PaddleDisWorkerProc():
rank=self.ranks)
# Initialize task queue
task_address = ('0.0.0.0',
task_address = (self.parallel_config.pod_ip,
self.parallel_config.engine_worker_queue_port)
self.task_queue = TaskQueue(
@@ -218,7 +218,8 @@ class PaddleDisWorkerProc():
TODO(gongshaotian): support remote calling of functions that control worker.
"""
        # Supports multi-node deployment: nnode is derived from
        # tensor_parallel_degree, assuming up to 8 ranks per node.
self.nnode = 1
self.nnode = int((self.parallel_config.tensor_parallel_degree + 7) // 8)
mp_num_per_node = self.parallel_config.tensor_parallel_degree // self.nnode
req_ids = []
while True:
if self.local_rank == 0:
@@ -236,8 +237,7 @@ class PaddleDisWorkerProc():
time.time())
# The first worker detects whether there are tasks in the task queue
mp_num_per_node = self.ranks / self.nnode
if self.local_rank % mp_num_per_node == 0:
if self.local_rank % mp_num_per_node == 0:
if self.task_queue.num_tasks() > 0:
if self.nnode > 1:
self.task_queue.read_finish_flag.set(1)
@@ -412,6 +412,7 @@ def parse_args():
help="max batch size")
parser.add_argument("--total_block_num", type=int, default=2000)
parser.add_argument("--block_size", type=int, default=64)
parser.add_argument("--pod_ip", type=str, default="127.0.0.1")
parser.add_argument("--engine_worker_queue_port", type=int, default=9923)
parser.add_argument("--max_model_len",
type=int,
@@ -600,6 +601,7 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
parallel_config.max_num_seqs = args.max_num_seqs
parallel_config.max_block_num = args.total_block_num
parallel_config.block_size = args.block_size
parallel_config.pod_ip = args.pod_ip
parallel_config.engine_worker_queue_port = args.engine_worker_queue_port
parallel_config.max_model_len = args.max_model_len
model_config.max_seq_len = args.max_model_len