mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
Adapt for iluvatar gpu (#2684)
This commit is contained in:
@@ -42,7 +42,7 @@ class ModelConfig:
|
||||
model_name_or_path: str,
|
||||
config_json_file: str = "config.json",
|
||||
dynamic_load_weight: bool = False,
|
||||
load_strategy: str="meta",
|
||||
load_strategy: str = "meta",
|
||||
quantization: str = None,
|
||||
download_dir: Optional[str] = None):
|
||||
"""
|
||||
@@ -590,7 +590,7 @@ class Config:
|
||||
self.nnode = 1
|
||||
else:
|
||||
self.nnode = len(self.pod_ips)
|
||||
|
||||
|
||||
assert self.splitwise_role in ["mixed", "prefill", "decode"]
|
||||
|
||||
# TODO
|
||||
@@ -608,8 +608,9 @@ class Config:
|
||||
== 1), "TP and EP cannot be enabled at the same time"
|
||||
|
||||
num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
|
||||
if num_ranks > 8:
|
||||
self.worker_num_per_node = 8
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
if num_ranks > self.max_chips_per_node:
|
||||
self.worker_num_per_node = self.max_chips_per_node
|
||||
nnode = ceil_div(num_ranks, self.worker_num_per_node)
|
||||
assert nnode == self.nnode, \
|
||||
f"nnode: {nnode}, but got {self.nnode}"
|
||||
@@ -679,8 +680,8 @@ class Config:
|
||||
is_port_available('0.0.0.0', self.engine_worker_queue_port)
|
||||
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
|
||||
assert (
|
||||
8 >= self.tensor_parallel_size > 0
|
||||
), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and 8"
|
||||
self.max_chips_per_node >= self.tensor_parallel_size > 0
|
||||
), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}"
|
||||
assert (self.nnode >= 1), f"nnode: {self.nnode} should no less than 1"
|
||||
assert (
|
||||
self.max_model_len >= 16
|
||||
@@ -816,7 +817,7 @@ class Config:
|
||||
|
||||
def _check_master(self):
|
||||
return self.is_master
|
||||
|
||||
|
||||
def _str_to_list(self, attr_name, default_type):
|
||||
if hasattr(self, attr_name):
|
||||
val = getattr(self, attr_name)
|
||||
|
Reference in New Issue
Block a user