Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 16:22:57 +08:00
[LLM] support multi node deploy (#2708)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* [LLM] support multi node deploy
* Update engine.py
* fix bugs
* fix
* [LLM] support multi node deploy
* [LLM] support multi node deploy

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -505,7 +505,6 @@ class Config:
                 model_name_or_path: str = None,
                 tokenizer: str = None,
                 tensor_parallel_size: int = 8,
-                nnode: int = 1,
                 max_model_len: int = 8192,
                 max_num_seqs: int = 8,
                 max_num_batched_tokens: Optional[int] = None,
@@ -539,7 +538,6 @@ class Config:
             model_name_or_path (str): Model directory path or model name.
             tokenizer (str): Default is the model.
             tensor_parallel_size (int): Tensor parallel size. Default is 8.
-            nnode (int): Number of nodes. Default is 1.
             max_model_len (int): Maximum model length. Default is 8192.
             max_num_seqs (int): Maximum number of sequences. Default is 8.
             max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
@@ -565,7 +563,6 @@ class Config:
         self.tokenizer = tokenizer
         self.max_num_batched_tokens = max_num_batched_tokens
         self.tensor_parallel_size = tensor_parallel_size
-        self.nnode = nnode
         self.pod_ips = pod_ips
         self.max_model_len = max_model_len
         self.max_num_seqs = max_num_seqs
@@ -585,12 +582,15 @@ class Config:
         self.max_capture_batch_size = max_capture_batch_size
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
+        self.is_master = True
+        self._str_to_list("innode_prefill_ports", int)
+        self._str_to_list("pod_ips", str)
 
-        if self.innode_prefill_ports is not None:
-            if not isinstance(self.innode_prefill_ports, list):
-                ports = str(self.innode_prefill_ports).split(',')
-                self.innode_prefill_ports = [int(port) for port in ports]
+        if self.pod_ips is None:
+            self.nnode = 1
+        else:
+            self.nnode = len(self.pod_ips)
 
         assert self.splitwise_role in ["mixed", "prefill", "decode"]
 
         # TODO
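With this hunk, the node count is derived from pod_ips instead of being passed in as nnode. A minimal standalone sketch of that rule (plain Python; the infer_nnode helper name and the IPs are ours, not part of the diff):

from typing import List, Optional

def infer_nnode(pod_ips: Optional[List[str]]) -> int:
    # Hypothetical helper mirroring the diff's rule: one node by default,
    # otherwise one node per entry in pod_ips.
    if pod_ips is None:
        return 1
    return len(pod_ips)

assert infer_nnode(None) == 1                      # single-node default
assert infer_nnode(["10.0.0.1", "10.0.0.2"]) == 2  # two-node deployment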
@@ -609,14 +609,15 @@ class Config:
 
         num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
         if num_ranks > 8:
-            local_num_ranks = 8
-            self.nnode = ceil_div(num_ranks, local_num_ranks)
+            self.worker_num_per_node = 8
+            nnode = ceil_div(num_ranks, self.worker_num_per_node)
+            assert nnode == self.nnode, \
+                f"nnode: {nnode}, but got {self.nnode}"
         else:
-            local_num_ranks = num_ranks
+            self.worker_num_per_node = num_ranks
 
         self.engine_worker_queue_port = engine_worker_queue_port
-        self.device_ids = ",".join([str(i) for i in range(min((self.tensor_parallel_size * \
-                                    self.parallel_config.expert_parallel_size), 8))])
+        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
         self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
 
         self.read_from_config()
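The placement above is plain arithmetic: total ranks are tensor_parallel_size * expert_parallel_size, capped at 8 workers per node, with ceil_div spreading the remainder across nodes. A standalone sketch of the same math (ceil_div is reimplemented here rather than imported from FastDeploy, and the TP/EP values are example inputs):

import os

def ceil_div(a: int, b: int) -> int:
    # Local stand-in for FastDeploy's ceil_div: ceiling division without floats.
    return -(a // -b)

# Example values: TP=16, EP=1 -> 16 ranks -> 8 workers per node -> 2 nodes.
tensor_parallel_size, expert_parallel_size = 16, 1
num_ranks = tensor_parallel_size * expert_parallel_size
worker_num_per_node = 8 if num_ranks > 8 else num_ranks
nnode = ceil_div(num_ranks, worker_num_per_node)
assert (worker_num_per_node, nnode) == (8, 2)

# Each node addresses only its local GPUs, unless CUDA_VISIBLE_DEVICES overrides the list.
device_ids = ",".join(str(i) for i in range(worker_num_per_node))
device_ids = os.getenv("CUDA_VISIBLE_DEVICES", device_ids)
print(device_ids)  # "0,1,2,3,4,5,6,7" when the env var is unset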
@@ -628,16 +629,21 @@ class Config:
         """
         calculate some parameters
         """
         total_rank = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
-        assert self.device_ids.split(',').__len__() == min(total_rank, 8), \
-            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {min(total_rank, 8)}"
+        assert self.device_ids.split(',').__len__() == self.worker_num_per_node, \
+            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
+
+        assert self.worker_num_per_node % self.tensor_parallel_size == 0, \
+            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
         self.local_device_ids = self.device_ids.split(
             ',')[:self.tensor_parallel_size]
-        assert self.tensor_parallel_size % self.nnode == 0, \
-            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
-        self.worker_num_per_node = total_rank // self.nnode
 
         self.host_ip = get_host_ip()
 
+        if self.pod_ips is None:
+            self.pod_ips = ["0.0.0.0"]
+        elif self.host_ip != self.pod_ips[0]:
+            self.is_master = False
+
         import paddle
         self.paddle_commit_id = paddle.version.commit
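Master election in this hunk is just an IP comparison: the first entry of pod_ips is treated as the master, and a node demotes itself when its own IP differs. A small sketch of that check (is_master_node is a hypothetical helper, not FastDeploy API; the addresses are made up):

from typing import List, Optional

def is_master_node(host_ip: str, pod_ips: Optional[List[str]]) -> bool:
    # Hypothetical helper mirroring the check above: single-node runs
    # (pod_ips unset) are always the master; otherwise compare against pod_ips[0].
    if pod_ips is None:
        return True
    return host_ip == pod_ips[0]

assert is_master_node("192.168.1.10", None) is True
assert is_master_node("192.168.1.10", ["192.168.1.10", "192.168.1.11"]) is True
assert is_master_node("192.168.1.11", ["192.168.1.10", "192.168.1.11"]) is False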
@@ -808,5 +814,16 @@ class Config:
                     "return_full_hidden_states")
         reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
 
+    def _check_master(self):
+        return self.is_master
+
+    def _str_to_list(self, attr_name, default_type):
+        if hasattr(self, attr_name):
+            val = getattr(self, attr_name)
+            if type(val) is str:
+                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
+            else:
+                setattr(self, attr_name, val)
+
     def __str__(self) -> str:
         return json.dumps(self.__dict__, indent=4)
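The new _str_to_list helper normalizes CLI-style comma-separated strings (pod_ips, innode_prefill_ports) into typed lists before the checks above run. A quick standalone exercise of the same logic (the _Demo class is a stand-in, not part of the diff; the IPs and ports are made up):

class _Demo:
    # Stand-in class to exercise the _str_to_list logic from the diff in isolation.
    def __init__(self, pod_ips, innode_prefill_ports):
        self.pod_ips = pod_ips
        self.innode_prefill_ports = innode_prefill_ports
        self._str_to_list("innode_prefill_ports", int)
        self._str_to_list("pod_ips", str)

    def _str_to_list(self, attr_name, default_type):
        # Same behavior as the added method: split a comma-separated string
        # into a typed list, pass through values that are already lists.
        if hasattr(self, attr_name):
            val = getattr(self, attr_name)
            if type(val) is str:
                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
            else:
                setattr(self, attr_name, val)

demo = _Demo(pod_ips="10.0.0.1,10.0.0.2", innode_prefill_ports="8021,8022")
assert demo.pod_ips == ["10.0.0.1", "10.0.0.2"]
assert demo.innode_prefill_ports == [8021, 8022]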