Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Others] Upgrade paddleformers to 0.4.0 (#5599)
This commit is contained in:
@@ -325,6 +325,9 @@ class ModelConfig:
|
||||
self.moe_num_experts = self.num_experts
|
||||
if hasattr(self, "n_routed_experts") and getattr(self, "moe_num_experts") is None:
|
||||
self.moe_num_experts = self.n_routed_experts
|
||||
if hasattr(self, "n_shared_experts") and getattr(self, "moe_num_shared_experts") is None:
|
||||
# Because the ERNIE 4.5 config.json contains two sets of keys, adaptation is required.
|
||||
self.moe_num_shared_experts = self.n_shared_experts
|
||||
|
||||
def read_from_env(self):
|
||||
"""
|
||||
|
||||
@@ -1243,6 +1243,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
down_proj_attrs,
|
||||
)
|
||||
else:
|
||||
# offline quant
|
||||
# 1.init shape
|
||||
extra_weight_attrs = {**extra_weight_attrs}
|
||||
if layer.fd_config.load_config.load_choices == "default_v1":
|
||||
@@ -1258,17 +1259,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
down_proj_scale_shape = self.down_proj_scale_shape[:1] + self.down_proj_scale_shape[1:][::-1]
|
||||
up_gate_proj_attrs = {
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(
|
||||
shape=up_gate_proj_weight_shape,
|
||||
output_dim=False,
|
||||
),
|
||||
}
|
||||
down_proj_attrs = {
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(
|
||||
shape=down_proj_weight_shape,
|
||||
output_dim=False,
|
||||
),
|
||||
}
|
||||
else:
|
||||
up_gate_proj_weight_shape = self.up_gate_proj_weight_shape
|
||||
|
||||
@@ -803,7 +803,7 @@ class DeepSeekV3PretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
|
||||
@@ -796,7 +796,7 @@ class Ernie4_5_MoePretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
|
||||
@@ -69,7 +69,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
@@ -170,7 +170,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
|
||||
if is_split:
|
||||
qkv_fn = partial(
|
||||
gqa_qkv_split_func,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -543,7 +542,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
[
|
||||
DFNRopeVisionBlock(
|
||||
config.vision_config,
|
||||
config.pretrained_config.tensor_parallel_degree,
|
||||
config.pretrained_config.tensor_model_parallel_size,
|
||||
config.pretrained_config.tensor_parallel_rank,
|
||||
model_format=model_format,
|
||||
)
|
||||
@@ -664,63 +663,6 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return self.forward(hidden_states, grid_thw)
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||
"""
|
||||
dummy
|
||||
"""
|
||||
|
||||
from paddleformers.transformers.conversion_utils import split_or_merge_func
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
)
|
||||
vision_config = config.vision_config
|
||||
|
||||
def split_qkv_weight(x):
|
||||
head_dim = vision_config.hidden_size // vision_config.num_heads
|
||||
x = x.reshape(
|
||||
[
|
||||
vision_config.hidden_size,
|
||||
3,
|
||||
vision_config.num_heads,
|
||||
head_dim,
|
||||
]
|
||||
)
|
||||
x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank]
|
||||
x = x.reshape([vision_config.hidden_size, -1])
|
||||
return x
|
||||
|
||||
def split_qkv_bias(x):
|
||||
head_dim = vision_config.hidden_size // vision_config.num_heads
|
||||
x = x.reshape([3, vision_config.num_heads, head_dim])
|
||||
x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank]
|
||||
x = x.reshape([-1])
|
||||
return x
|
||||
|
||||
def get_tensor_parallel_split_mappings(depth):
|
||||
final_actions = {}
|
||||
base_actions = {
|
||||
"vision_model.blocks.0.attn.proj.weight": partial(fn, is_column=False),
|
||||
"vision_model.blocks.0.fc1.weight": partial(fn, is_column=True),
|
||||
"vision_model.blocks.0.fc1.bias": partial(fn, is_column=True),
|
||||
"vision_model.blocks.0.fc2.weight": partial(fn, is_column=False),
|
||||
"vision_model.blocks.0.qkv.weight": split_qkv_weight,
|
||||
"vision_model.blocks.0.qkv.bias": split_qkv_bias,
|
||||
}
|
||||
|
||||
for key, action in base_actions.items():
|
||||
if "blocks.0." in key:
|
||||
for i in range(depth):
|
||||
newkey = key.replace("blocks.0.", f"blocks.{i}.")
|
||||
final_actions[newkey] = action
|
||||
return final_actions
|
||||
|
||||
mappings = get_tensor_parallel_split_mappings(vision_config.depth)
|
||||
return mappings
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for param_name, param in params_dict.items():
|
||||
|
||||
@@ -978,7 +978,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
@@ -986,7 +986,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
||||
)
|
||||
vision_fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.vision_config.get("num_heads"),
|
||||
num_key_value_heads=config.vision_config.get("num_heads"),
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -156,7 +155,7 @@ class VariableResolutionResamplerModel(nn.Layer):
|
||||
self.temporal_conv_size = temporal_conv_size
|
||||
self.use_recompute_resampler = False
|
||||
self.use_temporal_conv = True
|
||||
self.tensor_parallel_degree = config.pretrained_config.tensor_parallel_degree
|
||||
self.tensor_parallel_degree = config.pretrained_config.tensor_model_parallel_size
|
||||
self.prefix_name = prefix_name
|
||||
|
||||
# for 空间四合一
|
||||
@@ -351,31 +350,3 @@ class VariableResolutionResamplerModel(nn.Layer):
|
||||
raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}")
|
||||
else:
|
||||
param.copy_(tensor, False)
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||
|
||||
from paddleformers.transformers.conversion_utils import split_or_merge_func
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
res = {"spatial_linear.0.weight": partial(fn, is_column=False)}
|
||||
for k in (
|
||||
"spatial_linear.0.bias", # row linear bias
|
||||
"spatial_linear.2.weight",
|
||||
"spatial_linear.2.bias", # linear
|
||||
"spatial_linear.3.weight",
|
||||
"spatial_linear.3.bias", # layernorm
|
||||
"temporal_linear.0.weight",
|
||||
"temporal_linear.0.weight", # linear
|
||||
"temporal_linear.2.weight",
|
||||
"temporal_linear.2.bias", # linear
|
||||
"temporal_linear.3.weight",
|
||||
"temporal_linear.3.bias", # bias
|
||||
):
|
||||
res.update({k: lambda x: x})
|
||||
return res
|
||||
|
||||
@@ -549,7 +549,7 @@ class Glm4MoePretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
|
||||
@@ -445,7 +445,7 @@ class Qwen2PretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
@@ -468,7 +468,7 @@ class Qwen2PretrainedModel(PretrainedModel):
|
||||
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
|
||||
# if we have enough num_key_value_heads to split, then split it.
|
||||
if config.num_key_value_heads % config.tensor_parallel_degree == 0:
|
||||
if config.num_key_value_heads % config.tensor_model_parallel_size == 0:
|
||||
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
|
||||
|
||||
@@ -14,10 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
@@ -560,7 +558,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
num_heads=config.vision_config.num_heads,
|
||||
mlp_hidden_dim=config.vision_config.intermediate_size,
|
||||
hidden_act=config.vision_config.hidden_act,
|
||||
tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree,
|
||||
tensor_parallel_degree=config.pretrained_config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.pretrained_config.tensor_parallel_rank,
|
||||
model_format=model_format,
|
||||
)
|
||||
@@ -731,65 +729,6 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return self.forward(hidden_states, grid_thw)
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||
"""
|
||||
dummy
|
||||
"""
|
||||
|
||||
from paddleformers.transformers.conversion_utils import split_or_merge_func
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
)
|
||||
vision_config = config.vision_config
|
||||
|
||||
def split_qkv_weight(x):
|
||||
head_dim = vision_config.hidden_size // vision_config.num_heads
|
||||
x = x.reshape(
|
||||
[
|
||||
vision_config.hidden_size,
|
||||
3,
|
||||
vision_config.num_heads,
|
||||
head_dim,
|
||||
]
|
||||
)
|
||||
x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank]
|
||||
x = x.reshape([vision_config.hidden_size, -1])
|
||||
return x
|
||||
|
||||
def split_qkv_bias(x):
|
||||
head_dim = vision_config.hidden_size // vision_config.num_heads
|
||||
x = x.reshape([3, vision_config.num_heads, head_dim])
|
||||
x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank]
|
||||
x = x.reshape([-1])
|
||||
return x
|
||||
|
||||
def get_tensor_parallel_split_mappings(depth):
|
||||
final_actions = {}
|
||||
base_actions = {
|
||||
"visual.blocks.0.attn.proj.weight": partial(fn, is_column=False),
|
||||
"visual.blocks.0.mlp.gate_proj.weight": partial(fn, is_column=True),
|
||||
"visual.blocks.0.mlp.gate_proj.bias": partial(fn, is_column=True),
|
||||
"visual.blocks.0.mlp.up_proj.weight": partial(fn, is_column=True),
|
||||
"visual.blocks.0.mlp.up_proj.bias": partial(fn, is_column=True),
|
||||
"visual.blocks.0.mlp.down_proj.weight": partial(fn, is_column=False),
|
||||
"visual.blocks.0.qkv.weight": split_qkv_weight,
|
||||
"visual.blocks.0.qkv.bias": split_qkv_bias,
|
||||
}
|
||||
|
||||
for key, action in base_actions.items():
|
||||
if "blocks.0." in key:
|
||||
for i in range(depth):
|
||||
newkey = key.replace("blocks.0.", f"blocks.{i}.")
|
||||
final_actions[newkey] = action
|
||||
return final_actions
|
||||
|
||||
mappings = get_tensor_parallel_split_mappings(vision_config.depth)
|
||||
return mappings
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for param_name, param in params_dict.items():
|
||||
|
||||
@@ -388,7 +388,7 @@ class Qwen2_5_VLPretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
@@ -397,7 +397,7 @@ class Qwen2_5_VLPretrainedModel(PretrainedModel):
|
||||
|
||||
vision_fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_degree=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.vision_config.get("num_heads"),
|
||||
num_key_value_heads=config.vision_config.get("num_heads"),
|
||||
|
||||
@@ -386,7 +386,7 @@ class Qwen3PretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
@@ -407,7 +407,7 @@ class Qwen3PretrainedModel(PretrainedModel):
|
||||
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
|
||||
# if we have enough num_key_value_heads to split, then split it.
|
||||
if config.num_key_value_heads % config.tensor_parallel_degree == 0:
|
||||
if config.num_key_value_heads % config.tensor_model_parallel_size == 0:
|
||||
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
|
||||
|
||||
|
||||
@@ -470,7 +470,7 @@ class Qwen3MoePretrainedModel(PretrainedModel):
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_model_parallel_size=config.tensor_model_parallel_size,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
@@ -493,7 +493,7 @@ class Qwen3MoePretrainedModel(PretrainedModel):
|
||||
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
|
||||
# if we have enough num_key_value_heads to split, then split it.
|
||||
if config.num_key_value_heads % config.tensor_parallel_degree == 0:
|
||||
if config.num_key_value_heads % config.tensor_model_parallel_size == 0:
|
||||
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
|
||||
|
||||
@@ -453,7 +453,7 @@ def split_or_merge_func_v1(
|
||||
else:
|
||||
func = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=tensor_parallel_degree,
|
||||
tensor_model_parallel_size=tensor_parallel_degree,
|
||||
tensor_parallel_rank=tensor_parallel_rank,
|
||||
num_attention_heads=num_attention_heads,
|
||||
)
|
||||
|
||||
@@ -815,8 +815,8 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--load_choices",
|
||||
type=str,
|
||||
default="default",
|
||||
help="The format of the model weights to load. default/new_loader.",
|
||||
default="default_v1",
|
||||
help="The format of the model weights to load. default/default_v1.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -952,7 +952,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
|
||||
# Note(tangbinhan): used for load_checkpoint
|
||||
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
|
||||
model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
|
||||
model_config.pretrained_config.tensor_model_parallel_size = parallel_config.tensor_parallel_size
|
||||
model_config.pretrained_config.is_mtp = False
|
||||
model_config.pretrained_config.head_dim = model_config.head_dim
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ tqdm
|
||||
pynvml
|
||||
uvicorn>=0.38.0
|
||||
fastapi
|
||||
paddleformers>=0.3.1
|
||||
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
|
||||
redis
|
||||
etcd3
|
||||
httpx
|
||||
|
||||
@@ -10,7 +10,7 @@ tqdm
|
||||
pynvml
|
||||
uvicorn==0.29.0
|
||||
fastapi
|
||||
paddleformers
|
||||
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
|
||||
redis
|
||||
etcd3
|
||||
httpx
|
||||
|
||||
@@ -10,7 +10,7 @@ tqdm
|
||||
pynvml
|
||||
uvicorn==0.29.0
|
||||
fastapi
|
||||
paddleformers==0.4.0
|
||||
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
|
||||
redis
|
||||
etcd3
|
||||
httpx
|
||||
|
||||
@@ -10,7 +10,7 @@ tqdm
|
||||
pynvml
|
||||
uvicorn==0.29.0
|
||||
fastapi
|
||||
paddleformers==0.3.2
|
||||
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
|
||||
redis
|
||||
etcd3
|
||||
httpx
|
||||
|
||||
@@ -106,13 +106,13 @@ def _install_dependency_stubs():
|
||||
|
||||
conversion_utils = types.ModuleType("paddleformers.transformers.conversion_utils")
|
||||
|
||||
def _split_or_merge_func(is_split, tensor_parallel_degree, tensor_parallel_rank, **_kwargs):
|
||||
def _split_or_merge_func(is_split, tensor_model_parallel_size, tensor_parallel_rank, **_kwargs):
|
||||
axis = -1
|
||||
|
||||
def _fn(weight, *, is_column=True, **_kwargs):
|
||||
current_axis = axis if is_column else 0
|
||||
if is_split:
|
||||
chunks = np.array_split(weight, tensor_parallel_degree, axis=current_axis)
|
||||
chunks = np.array_split(weight, tensor_model_parallel_size, axis=current_axis)
|
||||
if tensor_parallel_rank is None:
|
||||
return chunks
|
||||
return chunks[tensor_parallel_rank]
|
||||
|
||||
Reference in New Issue
Block a user