diff --git a/fastdeploy/config.py b/fastdeploy/config.py index f802c53a8..381faa81e 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -325,6 +325,9 @@ class ModelConfig: self.moe_num_experts = self.num_experts if hasattr(self, "n_routed_experts") and getattr(self, "moe_num_experts") is None: self.moe_num_experts = self.n_routed_experts + if hasattr(self, "n_shared_experts") and getattr(self, "moe_num_shared_experts") is None: + # Because the ERNIE 4.5 config.json contains two sets of keys, adaptation is required. + self.moe_num_shared_experts = self.n_shared_experts def read_from_env(self): """ diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 2861d96e8..da705357c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -1243,6 +1243,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): down_proj_attrs, ) else: + # offline quant # 1.init shape extra_weight_attrs = {**extra_weight_attrs} if layer.fd_config.load_config.load_choices == "default_v1": @@ -1258,17 +1259,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): down_proj_scale_shape = self.down_proj_scale_shape[:1] + self.down_proj_scale_shape[1:][::-1] up_gate_proj_attrs = { **extra_weight_attrs, - "tensor_track": TensorTracker( - shape=up_gate_proj_weight_shape, - output_dim=False, - ), } down_proj_attrs = { **extra_weight_attrs, - "tensor_track": TensorTracker( - shape=down_proj_weight_shape, - output_dim=False, - ), } else: up_gate_proj_weight_shape = self.up_gate_proj_weight_shape diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index eff72b2c2..1948e7669 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -803,7 +803,7 @@ class DeepSeekV3PretrainedModel(PretrainedModel): fn = split_or_merge_func( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_model_parallel_size=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, ) diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index ef0554435..7314a4acf 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -796,7 +796,7 @@ class Ernie4_5_MoePretrainedModel(PretrainedModel): fn = split_or_merge_func_v1( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_degree=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index 2d57ed504..4ddba9a9b 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -69,7 +69,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel): fn = split_or_merge_func( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_model_parallel_size=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, ) @@ -170,7 +170,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel): if is_split: qkv_fn = partial( gqa_qkv_split_func, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_degree=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py index 2d8c53b22..0706bf2ab 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py @@ -14,7 +14,6 @@ # limitations under the License. """ -from functools import partial from typing import Optional import numpy as np @@ -543,7 +542,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): [ DFNRopeVisionBlock( config.vision_config, - config.pretrained_config.tensor_parallel_degree, + config.pretrained_config.tensor_model_parallel_size, config.pretrained_config.tensor_parallel_rank, model_format=model_format, ) @@ -664,63 +663,6 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): """ return self.forward(hidden_states, grid_thw) - @classmethod - def _get_tensor_parallel_mappings(cls, config, is_split=True): - """ - dummy - """ - - from paddleformers.transformers.conversion_utils import split_or_merge_func - - fn = split_or_merge_func( - is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, - tensor_parallel_rank=config.tensor_parallel_rank, - ) - vision_config = config.vision_config - - def split_qkv_weight(x): - head_dim = vision_config.hidden_size // vision_config.num_heads - x = x.reshape( - [ - vision_config.hidden_size, - 3, - vision_config.num_heads, - head_dim, - ] - ) - x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank] - x = x.reshape([vision_config.hidden_size, -1]) - return x - - def split_qkv_bias(x): - head_dim = vision_config.hidden_size // vision_config.num_heads - x = x.reshape([3, vision_config.num_heads, head_dim]) - x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank] - x = x.reshape([-1]) - return x - - def get_tensor_parallel_split_mappings(depth): - final_actions = {} - base_actions = { - "vision_model.blocks.0.attn.proj.weight": partial(fn, is_column=False), - "vision_model.blocks.0.fc1.weight": partial(fn, is_column=True), - "vision_model.blocks.0.fc1.bias": partial(fn, is_column=True), - "vision_model.blocks.0.fc2.weight": partial(fn, is_column=False), - "vision_model.blocks.0.qkv.weight": split_qkv_weight, - "vision_model.blocks.0.qkv.bias": split_qkv_bias, - } - - for key, action in base_actions.items(): - if "blocks.0." in key: - for i in range(depth): - newkey = key.replace("blocks.0.", f"blocks.{i}.") - final_actions[newkey] = action - return final_actions - - mappings = get_tensor_parallel_split_mappings(vision_config.depth) - return mappings - def load_state_dict(self, state_dict): params_dict = dict(self.named_parameters()) for param_name, param in params_dict.items(): diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 4e564cedd..5b7e9bd11 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -978,7 +978,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel): fn = split_or_merge_func_v1( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_degree=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, @@ -986,7 +986,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel): ) vision_fn = split_or_merge_func_v1( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_degree=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.vision_config.get("num_heads"), num_key_value_heads=config.vision_config.get("num_heads"), diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py index 552be1337..320827b0b 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py @@ -15,7 +15,6 @@ """ from copy import deepcopy -from functools import partial import numpy as np import paddle @@ -156,7 +155,7 @@ class VariableResolutionResamplerModel(nn.Layer): self.temporal_conv_size = temporal_conv_size self.use_recompute_resampler = False self.use_temporal_conv = True - self.tensor_parallel_degree = config.pretrained_config.tensor_parallel_degree + self.tensor_parallel_degree = config.pretrained_config.tensor_model_parallel_size self.prefix_name = prefix_name # for 空间四合一 @@ -351,31 +350,3 @@ class VariableResolutionResamplerModel(nn.Layer): raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}") else: param.copy_(tensor, False) - - @classmethod - def _get_tensor_parallel_mappings(cls, config, is_split=True): - - from paddleformers.transformers.conversion_utils import split_or_merge_func - - fn = split_or_merge_func( - is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, - tensor_parallel_rank=config.tensor_parallel_rank, - num_attention_heads=config.num_attention_heads, - ) - res = {"spatial_linear.0.weight": partial(fn, is_column=False)} - for k in ( - "spatial_linear.0.bias", # row linear bias - "spatial_linear.2.weight", - "spatial_linear.2.bias", # linear - "spatial_linear.3.weight", - "spatial_linear.3.bias", # layernorm - "temporal_linear.0.weight", - "temporal_linear.0.weight", # linear - "temporal_linear.2.weight", - "temporal_linear.2.bias", # linear - "temporal_linear.3.weight", - "temporal_linear.3.bias", # bias - ): - res.update({k: lambda x: x}) - return res diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index a466ae01e..495341d83 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -549,7 +549,7 @@ class Glm4MoePretrainedModel(PretrainedModel): fn = split_or_merge_func_v1( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_degree=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index 4707f837b..3064791bc 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -445,7 +445,7 @@ class Qwen2PretrainedModel(PretrainedModel): fn = split_or_merge_func( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_model_parallel_size=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, ) @@ -468,7 +468,7 @@ class Qwen2PretrainedModel(PretrainedModel): base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) # if we have enough num_key_value_heads to split, then split it. - if config.num_key_value_heads % config.tensor_parallel_degree == 0: + if config.num_key_value_heads % config.tensor_model_parallel_size == 0: base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py index 4414eb917..e8184f8d3 100644 --- a/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py +++ b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py @@ -14,10 +14,8 @@ # limitations under the License. """ -from functools import partial from typing import Optional -import numpy as np import paddle import paddle.nn.functional as F from paddle import nn @@ -560,7 +558,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): num_heads=config.vision_config.num_heads, mlp_hidden_dim=config.vision_config.intermediate_size, hidden_act=config.vision_config.hidden_act, - tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree, + tensor_parallel_degree=config.pretrained_config.tensor_model_parallel_size, tensor_parallel_rank=config.pretrained_config.tensor_parallel_rank, model_format=model_format, ) @@ -731,65 +729,6 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): """ return self.forward(hidden_states, grid_thw) - @classmethod - def _get_tensor_parallel_mappings(cls, config, is_split=True): - """ - dummy - """ - - from paddleformers.transformers.conversion_utils import split_or_merge_func - - fn = split_or_merge_func( - is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, - tensor_parallel_rank=config.tensor_parallel_rank, - ) - vision_config = config.vision_config - - def split_qkv_weight(x): - head_dim = vision_config.hidden_size // vision_config.num_heads - x = x.reshape( - [ - vision_config.hidden_size, - 3, - vision_config.num_heads, - head_dim, - ] - ) - x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank] - x = x.reshape([vision_config.hidden_size, -1]) - return x - - def split_qkv_bias(x): - head_dim = vision_config.hidden_size // vision_config.num_heads - x = x.reshape([3, vision_config.num_heads, head_dim]) - x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank] - x = x.reshape([-1]) - return x - - def get_tensor_parallel_split_mappings(depth): - final_actions = {} - base_actions = { - "visual.blocks.0.attn.proj.weight": partial(fn, is_column=False), - "visual.blocks.0.mlp.gate_proj.weight": partial(fn, is_column=True), - "visual.blocks.0.mlp.gate_proj.bias": partial(fn, is_column=True), - "visual.blocks.0.mlp.up_proj.weight": partial(fn, is_column=True), - "visual.blocks.0.mlp.up_proj.bias": partial(fn, is_column=True), - "visual.blocks.0.mlp.down_proj.weight": partial(fn, is_column=False), - "visual.blocks.0.qkv.weight": split_qkv_weight, - "visual.blocks.0.qkv.bias": split_qkv_bias, - } - - for key, action in base_actions.items(): - if "blocks.0." in key: - for i in range(depth): - newkey = key.replace("blocks.0.", f"blocks.{i}.") - final_actions[newkey] = action - return final_actions - - mappings = get_tensor_parallel_split_mappings(vision_config.depth) - return mappings - def load_state_dict(self, state_dict): params_dict = dict(self.named_parameters()) for param_name, param in params_dict.items(): diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py index e308f9483..021028edf 100644 --- a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py +++ b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py @@ -388,7 +388,7 @@ class Qwen2_5_VLPretrainedModel(PretrainedModel): fn = split_or_merge_func_v1( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_degree=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, @@ -397,7 +397,7 @@ class Qwen2_5_VLPretrainedModel(PretrainedModel): vision_fn = split_or_merge_func_v1( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_degree=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.vision_config.get("num_heads"), num_key_value_heads=config.vision_config.get("num_heads"), diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 1f5c85e66..f14823508 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -386,7 +386,7 @@ class Qwen3PretrainedModel(PretrainedModel): fn = split_or_merge_func( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_model_parallel_size=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, ) @@ -407,7 +407,7 @@ class Qwen3PretrainedModel(PretrainedModel): base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) # if we have enough num_key_value_heads to split, then split it. - if config.num_key_value_heads % config.tensor_parallel_degree == 0: + if config.num_key_value_heads % config.tensor_model_parallel_size == 0: base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index 647a1b164..6bf418f74 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -470,7 +470,7 @@ class Qwen3MoePretrainedModel(PretrainedModel): fn = split_or_merge_func( is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, + tensor_model_parallel_size=config.tensor_model_parallel_size, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, ) @@ -493,7 +493,7 @@ class Qwen3MoePretrainedModel(PretrainedModel): base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) # if we have enough num_key_value_heads to split, then split it. - if config.num_key_value_heads % config.tensor_parallel_degree == 0: + if config.num_key_value_heads % config.tensor_model_parallel_size == 0: base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) diff --git a/fastdeploy/model_executor/models/tp_utils.py b/fastdeploy/model_executor/models/tp_utils.py index 2283d1b3f..367d8cffb 100644 --- a/fastdeploy/model_executor/models/tp_utils.py +++ b/fastdeploy/model_executor/models/tp_utils.py @@ -453,7 +453,7 @@ def split_or_merge_func_v1( else: func = split_or_merge_func( is_split=is_split, - tensor_parallel_degree=tensor_parallel_degree, + tensor_model_parallel_size=tensor_parallel_degree, tensor_parallel_rank=tensor_parallel_rank, num_attention_heads=num_attention_heads, ) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index c3a3b5076..d82343a7e 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -815,8 +815,8 @@ def parse_args(): parser.add_argument( "--load_choices", type=str, - default="default", - help="The format of the model weights to load. default/new_loader.", + default="default_v1", + help="The format of the model weights to load. default/default_v1.", ) parser.add_argument( @@ -952,7 +952,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: # Note(tangbinhan): used for load_checkpoint model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank - model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size + model_config.pretrained_config.tensor_model_parallel_size = parallel_config.tensor_parallel_size model_config.pretrained_config.is_mtp = False model_config.pretrained_config.head_dim = model_config.head_dim diff --git a/requirements.txt b/requirements.txt index 4cac619c4..e351cd7c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ tqdm pynvml uvicorn>=0.38.0 fastapi -paddleformers>=0.3.1 +paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl redis etcd3 httpx diff --git a/requirements_dcu.txt b/requirements_dcu.txt index 714e0ae1d..1f0a20f2d 100644 --- a/requirements_dcu.txt +++ b/requirements_dcu.txt @@ -10,7 +10,7 @@ tqdm pynvml uvicorn==0.29.0 fastapi -paddleformers +paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl redis etcd3 httpx diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt index 0cb60ae88..fb0d702c4 100644 --- a/requirements_iluvatar.txt +++ b/requirements_iluvatar.txt @@ -10,7 +10,7 @@ tqdm pynvml uvicorn==0.29.0 fastapi -paddleformers==0.4.0 +paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl redis etcd3 httpx diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt index d49339b0f..96f1c4584 100644 --- a/requirements_metaxgpu.txt +++ b/requirements_metaxgpu.txt @@ -10,7 +10,7 @@ tqdm pynvml uvicorn==0.29.0 fastapi -paddleformers==0.3.2 +paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl redis etcd3 httpx diff --git a/tests/model_executor/test_tp_utils.py b/tests/model_executor/test_tp_utils.py index 97b6427ad..666733d55 100644 --- a/tests/model_executor/test_tp_utils.py +++ b/tests/model_executor/test_tp_utils.py @@ -106,13 +106,13 @@ def _install_dependency_stubs(): conversion_utils = types.ModuleType("paddleformers.transformers.conversion_utils") - def _split_or_merge_func(is_split, tensor_parallel_degree, tensor_parallel_rank, **_kwargs): + def _split_or_merge_func(is_split, tensor_model_parallel_size, tensor_parallel_rank, **_kwargs): axis = -1 def _fn(weight, *, is_column=True, **_kwargs): current_axis = axis if is_column else 0 if is_split: - chunks = np.array_split(weight, tensor_parallel_degree, axis=current_axis) + chunks = np.array_split(weight, tensor_model_parallel_size, axis=current_axis) if tensor_parallel_rank is None: return chunks return chunks[tensor_parallel_rank]