[LLM] support multi node deploy (#2708)

* [LLM] support multi node deploy

* Update engine.py

* fix bugs

* fix

* [LLM] support multi node deploy

* [LLM] support multi node deploy

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: ltd0924
Date: 2025-07-06 10:33:51 +08:00
Committed by: GitHub
Parent: 04a8e1ef2b
Commit: 68b4755587
13 changed files with 157 additions and 87 deletions
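
At a glance, the user-visible change: `nnode` is no longer passed to `Config`; the node count is inferred from `pod_ips`, and the first entry of `pod_ips` is treated as the master node. A minimal sketch of that rule (the helper name and IPs below are illustrative, not part of the diff):

```python
def infer_nnode(pod_ips):
    # Single-node run when no pod IPs are given; otherwise one node per pod IP.
    return 1 if pod_ips is None else len(pod_ips)

print(infer_nnode(None))                      # 1
print(infer_nnode(["10.0.0.1", "10.0.0.2"]))  # 2 -- 10.0.0.1 would be the master
```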


@@ -505,7 +505,6 @@ class Config:
         model_name_or_path: str = None,
         tokenizer: str = None,
         tensor_parallel_size: int = 8,
-        nnode: int = 1,
         max_model_len: int = 8192,
         max_num_seqs: int = 8,
         max_num_batched_tokens: Optional[int] = None,
@@ -539,7 +538,6 @@ class Config:
             model_name_or_path (str): Model directory path or model name.
             tokenizer (str): Default is the model.
             tensor_parallel_size (int): Tensor parallel size. Default is 8.
-            nnode (int): Number of nodes. Default is 1.
             max_model_len (int): Maximum model length. Default is 8192.
             max_num_seqs (int): Maximum number of sequences. Default is 8.
             max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
@@ -565,7 +563,6 @@ class Config:
         self.tokenizer = tokenizer
         self.max_num_batched_tokens = max_num_batched_tokens
         self.tensor_parallel_size = tensor_parallel_size
-        self.nnode = nnode
         self.pod_ips = pod_ips
         self.max_model_len = max_model_len
         self.max_num_seqs = max_num_seqs
@@ -585,12 +582,15 @@ class Config:
         self.max_capture_batch_size = max_capture_batch_size
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
+        self.is_master = True
+        self._str_to_list("innode_prefill_ports", int)
+        self._str_to_list("pod_ips", str)
-        if self.innode_prefill_ports is not None:
-            if not isinstance(self.innode_prefill_ports, list):
-                ports = str(self.innode_prefill_ports).split(',')
-                self.innode_prefill_ports = [int(port) for port in ports]
+        if self.pod_ips is None:
+            self.nnode = 1
+        else:
+            self.nnode = len(self.pod_ips)
         assert self.splitwise_role in ["mixed", "prefill", "decode"]
         # TODO
@@ -609,14 +609,15 @@ class Config:
         num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
         if num_ranks > 8:
-            local_num_ranks = 8
-            self.nnode = ceil_div(num_ranks, local_num_ranks)
+            self.worker_num_per_node = 8
+            nnode = ceil_div(num_ranks, self.worker_num_per_node)
+            assert nnode == self.nnode, \
+                f"nnode: {nnode}, but got {self.nnode}"
         else:
-            local_num_ranks = num_ranks
+            self.worker_num_per_node = num_ranks
         self.engine_worker_queue_port = engine_worker_queue_port
-        self.device_ids = ",".join([str(i) for i in range(min((self.tensor_parallel_size * \
-            self.parallel_config.expert_parallel_size), 8))])
+        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
         self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
         self.read_from_config()
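
For reference, the arithmetic behind `worker_num_per_node` and `device_ids` in the hunk above, written out as a standalone sketch (the parallel sizes are example values; `ceil_div` is reproduced here as integer ceiling division):

```python
tensor_parallel_size, expert_parallel_size = 8, 2
num_ranks = tensor_parallel_size * expert_parallel_size    # 16 ranks in total

# At most 8 workers run on one node; extra ranks spill onto additional nodes.
worker_num_per_node = 8 if num_ranks > 8 else num_ranks    # -> 8
nnode = -(-num_ranks // worker_num_per_node)               # ceil_div(16, 8) -> 2

# Default per-node device list, unless CUDA_VISIBLE_DEVICES overrides it.
device_ids = ",".join(str(i) for i in range(worker_num_per_node))
print(worker_num_per_node, nnode, device_ids)              # 8 2 0,1,2,3,4,5,6,7
```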
@@ -628,16 +629,21 @@ class Config:
         """
        calculate some parameters
        """
-        total_rank = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
-        assert self.device_ids.split(',').__len__() == min(total_rank, 8), \
-            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {min(total_rank, 8)}"
+        assert self.device_ids.split(',').__len__() == self.worker_num_per_node, \
+            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
+        assert self.worker_num_per_node % self.tensor_parallel_size == 0, \
+            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
         self.local_device_ids = self.device_ids.split(
             ',')[:self.tensor_parallel_size]
-        assert self.tensor_parallel_size % self.nnode == 0, \
-            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
-        self.worker_num_per_node = total_rank // self.nnode
         self.host_ip = get_host_ip()
+        if self.pod_ips is None:
+            self.pod_ips = ["0.0.0.0"]
+        elif self.host_ip != self.pod_ips[0]:
+            self.is_master = False
         import paddle
         self.paddle_commit_id = paddle.version.commit
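
The master-node check added above, isolated into a runnable sketch (the host and pod IPs are made-up values; in the engine, `host_ip` comes from `get_host_ip()`):

```python
pod_ips = ["10.10.1.1", "10.10.1.2"]
host_ip = "10.10.1.2"        # pretend this machine resolved to the second pod IP
is_master = True             # default set earlier in __init__

if pod_ips is None:
    pod_ips = ["0.0.0.0"]    # single-node fallback
elif host_ip != pod_ips[0]:
    is_master = False        # only the machine matching pod_ips[0] is the master

print(is_master)             # False
```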
@@ -808,5 +814,16 @@ class Config:
                     "return_full_hidden_states")
         reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
+    def _check_master(self):
+        return self.is_master
+    def _str_to_list(self, attr_name, default_type):
+        if hasattr(self, attr_name):
+            val = getattr(self, attr_name)
+            if type(val) is str:
+                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
+            else:
+                setattr(self, attr_name, val)
     def __str__(self) -> str:
         return json.dumps(self.__dict__, indent=4)
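
To make the new `_str_to_list` helper concrete, here is a standalone copy applied to illustrative values (the demo class, port numbers, and IPs are not from the codebase):

```python
class _Demo:
    """Holds the two attributes that Config normalizes with _str_to_list."""

    def __init__(self, innode_prefill_ports, pod_ips):
        self.innode_prefill_ports = innode_prefill_ports
        self.pod_ips = pod_ips

    # Copied from the hunk above: turn a comma-separated string into a typed
    # list, and leave values that are already lists (or None) untouched.
    def _str_to_list(self, attr_name, default_type):
        if hasattr(self, attr_name):
            val = getattr(self, attr_name)
            if type(val) is str:
                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
            else:
                setattr(self, attr_name, val)


demo = _Demo(innode_prefill_ports="8021,8022", pod_ips=["10.0.0.1", "10.0.0.2"])
demo._str_to_list("innode_prefill_ports", int)
demo._str_to_list("pod_ips", str)
print(demo.innode_prefill_ports, demo.pod_ips)   # [8021, 8022] ['10.0.0.1', '10.0.0.2']
```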