[Feature] ernie4_5_vl_moe: support Hugging Face safetensors loading (#3750)

* update

* update

* update in tp

* add todo

* update

---------

Co-authored-by: aquagull <hongyuh@qq.com>
Author: Ayakouji
Date: 2025-09-03 17:58:59 +08:00
Committed by: GitHub
Parent: 4c998c3636
Commit: 31313e0f3d

4 changed files with 59 additions and 12 deletions


@@ -100,13 +100,11 @@ class VocabParallelEmbedding(nn.Layer):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
         if self.tie_word_embeddings:
-            self.embeddings.weight.set_value(
-                get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
-            )
+            weight_tensor = get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
         else:
-            self.embeddings.weight.set_value(
-                get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
-            )
+            weight_tensor = get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
+        self.embeddings.weight.set_value(weight_tensor)
 
     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         """


@@ -156,7 +156,12 @@ class VisionFlashAttention2(nn.Layer):
     """
 
     def __init__(
-        self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1, tensor_parallel_rank: int = 0
+        self,
+        dim: int,
+        num_heads: int = 16,
+        tensor_parallel_degree: int = 1,
+        tensor_parallel_rank: int = 0,
+        model_format: str = "",
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -180,19 +185,25 @@ class VisionFlashAttention2(nn.Layer):
                 has_bias=True,
             )
             set_weight_attrs(self.qkv.weight, {"weight_loader": self.weight_loader})
-            set_weight_attrs(self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True})
-            set_weight_attrs(self.qkv.bias, {"output_dim": True})
+            set_weight_attrs(
+                self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True, "output_dim": True}
+            )
             set_weight_attrs(self.proj.weight, {"output_dim": False})
         else:
             self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
             self.proj = nn.Linear(dim, dim)
+            set_weight_attrs(self.qkv.weight, {"model_format": model_format})
+            set_weight_attrs(self.proj.weight, {"model_format": model_format})
 
         self.head_dim = dim // num_heads  # must added
         self.num_heads = num_heads
         self.hidden_size = dim
         self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
         load_bias = getattr(param, "load_bias", None)
         if load_bias:
             head_dim = self.hidden_size // self.num_heads
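
The transpose in weight_loader is needed because PyTorch's nn.Linear stores its weight as [out_features, in_features], while Paddle's nn.Linear stores it as [in_features, out_features]. A minimal sketch of that layout conversion (NumPy only; the helper name is illustrative, not FastDeploy API):

import numpy as np

def to_paddle_layout(loaded_weight: np.ndarray, model_format: str) -> np.ndarray:
    # torch checkpoints: [out_features, in_features]; paddle expects [in_features, out_features]
    if model_format == "torch":
        return loaded_weight.transpose(1, 0)
    return loaded_weight

qkv_from_hf = np.zeros((3 * 1280, 1280), dtype=np.float32)  # torch layout for a qkv projection, dim=1280
print(to_paddle_layout(qkv_from_hf, "torch").shape)  # (1280, 3840) -> paddle layout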
@@ -320,6 +331,7 @@ class VisionMlp(nn.Layer):
         hidden_dim: int,
         hidden_act: str,
         tensor_parallel_degree: int = 1,
+        model_format: str = "",
     ) -> None:
         super().__init__()
         self.tensor_parallel_degree = tensor_parallel_degree
@@ -345,6 +357,10 @@ class VisionMlp(nn.Layer):
         else:
             self.fc1 = nn.Linear(dim, hidden_dim)
             self.fc2 = nn.Linear(hidden_dim, dim)
+            set_weight_attrs(self.fc1.weight, {"model_format": model_format})
+            set_weight_attrs(self.fc2.weight, {"model_format": model_format})
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x) -> paddle.Tensor:
@@ -403,6 +419,7 @@ class DFNRopeVisionBlock(nn.Layer):
         tensor_parallel_degree: int,
         tensor_parallel_rank: int,
         attn_implementation: str = "sdpa",
+        model_format: str = "",
     ) -> None:
         """_summary_
@@ -420,12 +437,14 @@ class DFNRopeVisionBlock(nn.Layer):
             num_heads=config.num_heads,
             tensor_parallel_degree=tensor_parallel_degree,
             tensor_parallel_rank=tensor_parallel_rank,
+            model_format=model_format,
         )
         self.mlp = VisionMlp(
             dim=config.embed_dim,
             hidden_dim=mlp_hidden_dim,
             hidden_act=config.hidden_act,
             tensor_parallel_degree=tensor_parallel_degree,
+            model_format=model_format,
         )
         self.config = config
@@ -509,6 +528,8 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
             in_channels=config.vision_config.in_channels,
             embed_dim=config.vision_config.embed_dim,
         )
+        model_format = getattr(config, "model_format", "")
+        set_weight_attrs(self.patch_embed.proj.weight, {"model_format": model_format})
 
         head_dim = config.vision_config.embed_dim // config.vision_config.num_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -519,6 +540,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
                     config.vision_config,
                     config.pretrained_config.tensor_parallel_degree,
                     config.pretrained_config.tensor_parallel_rank,
+                    model_format=model_format,
                 )
                 for _ in range(config.vision_config.depth)
             ]
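
Throughout the vision stack, set_weight_attrs tags parameters with the checkpoint's source format so each weight loader can decide at load time whether a transpose is needed. A hedged sketch of what such a tagging helper amounts to (the real FastDeploy utility may differ in details):

def set_weight_attrs(param, attrs: dict) -> None:
    # Attach load-time metadata (e.g. {"model_format": "torch"}) to a parameter object;
    # weight loaders later read it back via getattr(param, "model_format", "").
    for key, value in attrs.items():
        setattr(param, key, value)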


@@ -156,6 +156,13 @@ class Ernie4_5_VLMoeBlock(nn.Layer):
             weight_key="weight" if moe_tag == "Text" else "weight_1",
         )
+        # TODO(hehongyu): remove this after fix model network
+        setattr(
+            self.gate.weight,
+            "model_format",
+            "",
+        )
 
     def forward(self, hidden_states: paddle.Tensor):
         out = self.experts(hidden_states, self.gate)
         return out
@@ -609,6 +616,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             ("resampler_model", "ernie.resampler_model", None, None),
             ("vision_model", "ernie.vision_model", None, None),
             ("gate_correction_bias", "moe_statics.e_score_correction_bias", None, None),
+            # for torch model
+            ("resampler_model", "model.resampler_model", None, None),
+            ("qkv_proj", "q_proj", None, "q"),
+            ("qkv_proj", "k_proj", None, "k"),
+            ("qkv_proj", "v_proj", None, "v"),
+            ("up_gate_proj", "gate_proj", None, "gate"),
+            ("up_gate_proj", "up_proj", None, "up"),
         ]
 
         text_expert_params_mapping = []
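
The new tuples map the per-projection names used by torch/Hugging Face checkpoints (q_proj/k_proj/v_proj, gate_proj/up_proj) onto FastDeploy's fused parameters (qkv_proj, up_gate_proj), with the shard id telling the loader which slice of the fused tensor a checkpoint tensor fills. An illustrative, self-contained resolution helper (not the FastDeploy API):

mapping = [
    ("qkv_proj", "q_proj", None, "q"),
    ("qkv_proj", "k_proj", None, "k"),
    ("qkv_proj", "v_proj", None, "v"),
    ("up_gate_proj", "gate_proj", None, "gate"),
    ("up_gate_proj", "up_proj", None, "up"),
]

def resolve(loaded_name: str):
    # Return the fused parameter name plus the shard this checkpoint tensor fills.
    for param_name, weight_name, _expert_id, shard_id in mapping:
        if weight_name in loaded_name:
            return loaded_name.replace(weight_name, param_name), shard_id
    return loaded_name, None

print(resolve("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')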
@@ -617,6 +631,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
                 num_experts=self.fd_config.model_config.moe_num_experts[0],
                 ckpt_down_proj_name="down_proj",
                 ckpt_gate_up_proj_name="up_gate_proj",
+                ckpt_gate_proj_name="gate_proj",
+                ckpt_up_proj_name="up_proj",
                 param_gate_up_proj_name="text_fused_moe.experts.up_gate_proj_",
                 param_down_proj_name="text_fused_moe.experts.down_proj_",
             )
@@ -624,6 +640,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
                 num_experts=self.fd_config.model_config.moe_num_experts[1],
                 ckpt_down_proj_name="down_proj",
                 ckpt_gate_up_proj_name="up_gate_proj",
+                ckpt_gate_proj_name="gate_proj",
+                ckpt_up_proj_name="up_proj",
                 param_gate_up_proj_name="image_fused_moe.experts.up_gate_proj_",
                 param_down_proj_name="image_fused_moe.experts.down_proj_",
                 experts_offset=self.fd_config.model_config.moe_num_experts[0],
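
Passing ckpt_gate_proj_name and ckpt_up_proj_name alongside ckpt_gate_up_proj_name lets the expert mapping also cover Hugging Face checkpoints that store gate_proj and up_proj separately rather than a pre-fused up_gate_proj. Conceptually the two per-expert projections end up concatenated along the output dimension of one fused tensor (the exact ordering is handled by the loader); a toy illustration with made-up shapes (hidden=8, intermediate=16, Paddle's [in, out] layout):

import numpy as np

gate_proj = np.random.rand(8, 16).astype(np.float32)  # expert's gate projection
up_proj = np.random.rand(8, 16).astype(np.float32)    # expert's up projection
up_gate_proj = np.concatenate([gate_proj, up_proj], axis=-1)
print(up_gate_proj.shape)  # (8, 32): one fused tensor per expert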
@@ -637,9 +655,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         shard_id = None
         for loaded_weight_name, loaded_weight in weights_iterator:
             for param_name, weight_name, exp_id, shard_id in all_param_mapping:
-                if weight_name not in loaded_weight_name:
-                    continue
                 model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                if model_param_name.startswith("model.") and self.fd_config.model_config.model_format == "torch":
+                    model_param_name = model_param_name.replace("model.", "ernie.")
+                if model_param_name not in params_dict:
+                    continue
                 param = params_dict[model_param_name]
                 expert_id = exp_id
                 shard_id = shard_id
@@ -657,7 +678,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
                 if "expert_id" in sig.parameters:
                     weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
                 else:
-                    weight_loader(param, loaded_weight)
+                    weight_loader(param, loaded_weight, shard_id)
             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
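
The else branch now forwards shard_id, so fused parameters (for example qkv_proj built from per-projection torch weights) land in the right slice; which call form is used depends on the loader's signature. A small sketch of that dispatch pattern (the loader name here is illustrative, not FastDeploy's):

import inspect

def dispatch(weight_loader, param, loaded_weight, expert_id=None, shard_id=None):
    # Some loaders (MoE experts) accept expert_id and shard_id; others take only a shard id.
    sig = inspect.signature(weight_loader)
    if "expert_id" in sig.parameters:
        weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
    else:
        weight_loader(param, loaded_weight, shard_id)

def fused_qkv_loader(param, loaded_weight, loaded_shard_id=None):
    print(f"loading shard {loaded_shard_id} of {getattr(loaded_weight, 'shape', None)}")

dispatch(fused_qkv_loader, param=None, loaded_weight=None, shard_id="q")  # -> loading shard q of None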


@@ -181,6 +181,8 @@ class VariableResolutionResamplerModel(nn.Layer):
             nn.Linear(self.spatial_dim, self.spatial_dim),
             nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
         )
+        set_weight_attrs(self.spatial_linear[0].weight, {"model_format": config.model_format})
+        set_weight_attrs(self.spatial_linear[2].weight, {"model_format": config.model_format})
 
         if self.use_temporal_conv:
             self.temporal_linear = nn.Sequential(
@@ -189,9 +191,13 @@ class VariableResolutionResamplerModel(nn.Layer):
                 nn.Linear(self.spatial_dim, self.spatial_dim),
                 nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
             )
+            set_weight_attrs(self.temporal_linear[0].weight, {"model_format": config.model_format})
+            set_weight_attrs(self.temporal_linear[2].weight, {"model_format": config.model_format})
 
         self.mlp = nn.Linear(self.spatial_dim, self.out_dim)
+        set_weight_attrs(self.mlp.weight, {"model_format": config.model_format})
 
         out_config = deepcopy(config)
         out_config.hidden_size = out_dim
         self.after_norm = RMSNorm(out_config)