Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] ernie4_5_vl_moe support huggingface safetensor loading (#3750)

* update
* update
* update in tp
* add todo
* update

Co-authored-by: aquagull <hongyuh@qq.com>
@@ -100,13 +100,11 @@ class VocabParallelEmbedding(nn.Layer):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
         if self.tie_word_embeddings:
-            self.embeddings.weight.set_value(
-                get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
-            )
+            weight_tensor = get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
         else:
-            self.embeddings.weight.set_value(
-                get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
-            )
+            weight_tensor = get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
+        self.embeddings.weight.set_value(weight_tensor)

     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         """
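The hunk above hoists the tensor fetch out of the two branches so `set_value` runs exactly once: the tied-embeddings path reads the weight while leaving it in `state_dict` (the output projection reuses it), the untied path pops it. A minimal sketch of the same pattern, using numpy and illustrative names rather than FastDeploy's actual helpers:

```python
# Sketch only: shows the refactor pattern from the hunk above with stand-in helpers.
# `get_tensor` and the paddle dtype handling are simplified away.
import numpy as np


def load_embedding_weight(state_dict: dict, prefix: str, tie_word_embeddings: bool) -> np.ndarray:
    if tie_word_embeddings:
        # Tied case: keep the entry, the output projection will read it again.
        weight_tensor = np.asarray(state_dict[prefix + ".weight"], dtype=np.float32)
    else:
        # Untied case: pop the entry so it is not consumed twice.
        weight_tensor = np.asarray(state_dict.pop(prefix + ".weight"), dtype=np.float32)
    # Single assignment point, mirroring `self.embeddings.weight.set_value(weight_tensor)`.
    return weight_tensor


if __name__ == "__main__":
    sd = {"embed_tokens.weight": [[0.1, 0.2], [0.3, 0.4]]}
    w = load_embedding_weight(sd, "embed_tokens", tie_word_embeddings=False)
    print(w.shape, "embed_tokens.weight" in sd)  # (2, 2) False
```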
@@ -156,7 +156,12 @@ class VisionFlashAttention2(nn.Layer):
     """

     def __init__(
-        self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1, tensor_parallel_rank: int = 0
+        self,
+        dim: int,
+        num_heads: int = 16,
+        tensor_parallel_degree: int = 1,
+        tensor_parallel_rank: int = 0,
+        model_format: str = "",
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -180,19 +185,25 @@ class VisionFlashAttention2(nn.Layer):
                 has_bias=True,
             )
             set_weight_attrs(self.qkv.weight, {"weight_loader": self.weight_loader})
-            set_weight_attrs(self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True})
-            set_weight_attrs(self.qkv.bias, {"output_dim": True})
+            set_weight_attrs(
+                self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True, "output_dim": True}
+            )
             set_weight_attrs(self.proj.weight, {"output_dim": False})
         else:
             self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
             self.proj = nn.Linear(dim, dim)

+        set_weight_attrs(self.qkv.weight, {"model_format": model_format})
+        set_weight_attrs(self.proj.weight, {"model_format": model_format})
         self.head_dim = dim // num_heads  # must added
         self.num_heads = num_heads
         self.hidden_size = dim
         self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
         load_bias = getattr(param, "load_bias", None)
         if load_bias:
             head_dim = self.hidden_size // self.num_heads
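The new `weight_loader` lines key off a `model_format` attribute stamped onto each parameter: PyTorch/safetensors checkpoints store `nn.Linear` weights as (out_features, in_features), whereas Paddle linears expect (in_features, out_features), so torch-format tensors are transposed before any tensor-parallel slicing. A small, framework-agnostic sketch of that layout conversion (names are illustrative):

```python
# Sketch only: torch-layout -> paddle-layout conversion guarded by a model_format flag,
# mirroring the `if model_format == "torch": loaded_weight.transpose([1, 0])` line above.
import numpy as np


def to_paddle_layout_if_needed(loaded_weight: np.ndarray, model_format: str) -> np.ndarray:
    if model_format == "torch":
        # torch nn.Linear weight: (out_features, in_features)
        # paddle nn.Linear weight: (in_features, out_features)
        return loaded_weight.transpose((1, 0))
    return loaded_weight


if __name__ == "__main__":
    fused_qkv_torch = np.zeros((3 * 1280, 1280))  # hypothetical fused qkv weight, torch layout
    print(to_paddle_layout_if_needed(fused_qkv_torch, "torch").shape)  # (1280, 3840)
    print(to_paddle_layout_if_needed(fused_qkv_torch, "").shape)       # unchanged
```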
@@ -320,6 +331,7 @@ class VisionMlp(nn.Layer):
         hidden_dim: int,
         hidden_act: str,
         tensor_parallel_degree: int = 1,
+        model_format: str = "",
     ) -> None:
         super().__init__()
         self.tensor_parallel_degree = tensor_parallel_degree
@@ -345,6 +357,10 @@ class VisionMlp(nn.Layer):
         else:
             self.fc1 = nn.Linear(dim, hidden_dim)
             self.fc2 = nn.Linear(hidden_dim, dim)
+
+        set_weight_attrs(self.fc1.weight, {"model_format": model_format})
+        set_weight_attrs(self.fc2.weight, {"model_format": model_format})
+
         self.act = ACT2FN[hidden_act]

     def forward(self, x) -> paddle.Tensor:
@@ -403,6 +419,7 @@ class DFNRopeVisionBlock(nn.Layer):
         tensor_parallel_degree: int,
         tensor_parallel_rank: int,
         attn_implementation: str = "sdpa",
+        model_format: str = "",
     ) -> None:
         """_summary_

@@ -420,12 +437,14 @@ class DFNRopeVisionBlock(nn.Layer):
             num_heads=config.num_heads,
             tensor_parallel_degree=tensor_parallel_degree,
             tensor_parallel_rank=tensor_parallel_rank,
+            model_format=model_format,
         )
         self.mlp = VisionMlp(
             dim=config.embed_dim,
             hidden_dim=mlp_hidden_dim,
             hidden_act=config.hidden_act,
             tensor_parallel_degree=tensor_parallel_degree,
+            model_format=model_format,
         )
         self.config = config

@@ -509,6 +528,8 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
             in_channels=config.vision_config.in_channels,
             embed_dim=config.vision_config.embed_dim,
         )
+        model_format = getattr(config, "model_format", "")
+        set_weight_attrs(self.patch_embed.proj.weight, {"model_format": model_format})

         head_dim = config.vision_config.embed_dim // config.vision_config.num_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -519,6 +540,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
                     config.vision_config,
                     config.pretrained_config.tensor_parallel_degree,
                     config.pretrained_config.tensor_parallel_rank,
+                    model_format=model_format,
                 )
                 for _ in range(config.vision_config.depth)
             ]
@@ -156,6 +156,13 @@ class Ernie4_5_VLMoeBlock(nn.Layer):
             weight_key="weight" if moe_tag == "Text" else "weight_1",
         )

+        # TODO(hehongyu): remove this after fix model network
+        setattr(
+            self.gate.weight,
+            "model_format",
+            "",
+        )
+
     def forward(self, hidden_states: paddle.Tensor):
         out = self.experts(hidden_states, self.gate)
         return out
@@ -609,6 +616,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             ("resampler_model", "ernie.resampler_model", None, None),
             ("vision_model", "ernie.vision_model", None, None),
             ("gate_correction_bias", "moe_statics.e_score_correction_bias", None, None),
+            # for torch model
+            ("resampler_model", "model.resampler_model", None, None),
+            ("qkv_proj", "q_proj", None, "q"),
+            ("qkv_proj", "k_proj", None, "k"),
+            ("qkv_proj", "v_proj", None, "v"),
+            ("up_gate_proj", "gate_proj", None, "gate"),
+            ("up_gate_proj", "up_proj", None, "up"),
         ]

         text_expert_params_mapping = []
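The extra tuples above let a HuggingFace-style checkpoint, whose attention and MLP weights are stored as separate `q_proj`/`k_proj`/`v_proj` and `gate_proj`/`up_proj` tensors under a `model.` prefix, be routed into FastDeploy's fused `qkv_proj` and `up_gate_proj` parameters; the last element is the shard id the fused weight loader uses to place each slice. A toy resolver built on the same kind of table (a reduced example, not the full mapping):

```python
# Sketch only: resolve a checkpoint tensor name against a (param, ckpt, expert_id, shard_id)
# mapping in the style of the diff above.
MAPPING = [
    ("qkv_proj", "q_proj", None, "q"),
    ("qkv_proj", "k_proj", None, "k"),
    ("qkv_proj", "v_proj", None, "v"),
    ("up_gate_proj", "gate_proj", None, "gate"),
    ("up_gate_proj", "up_proj", None, "up"),
]


def resolve(loaded_weight_name: str):
    """Return (model_param_name, expert_id, shard_id) for a checkpoint tensor name."""
    for param_name, weight_name, exp_id, shard_id in MAPPING:
        if weight_name in loaded_weight_name:
            return loaded_weight_name.replace(weight_name, param_name), exp_id, shard_id
    return loaded_weight_name, None, None


if __name__ == "__main__":
    print(resolve("model.layers.0.self_attn.k_proj.weight"))
    # ('model.layers.0.self_attn.qkv_proj.weight', None, 'k')
```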
@@ -617,6 +631,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             num_experts=self.fd_config.model_config.moe_num_experts[0],
             ckpt_down_proj_name="down_proj",
             ckpt_gate_up_proj_name="up_gate_proj",
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_up_proj_name="up_proj",
             param_gate_up_proj_name="text_fused_moe.experts.up_gate_proj_",
             param_down_proj_name="text_fused_moe.experts.down_proj_",
         )
@@ -624,6 +640,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             num_experts=self.fd_config.model_config.moe_num_experts[1],
             ckpt_down_proj_name="down_proj",
             ckpt_gate_up_proj_name="up_gate_proj",
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_up_proj_name="up_proj",
             param_gate_up_proj_name="image_fused_moe.experts.up_gate_proj_",
             param_down_proj_name="image_fused_moe.experts.down_proj_",
             experts_offset=self.fd_config.model_config.moe_num_experts[0],
|
|||||||
shard_id = None
|
shard_id = None
|
||||||
for loaded_weight_name, loaded_weight in weights_iterator:
|
for loaded_weight_name, loaded_weight in weights_iterator:
|
||||||
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
|
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
|
||||||
if weight_name not in loaded_weight_name:
|
|
||||||
continue
|
|
||||||
model_param_name = loaded_weight_name.replace(weight_name, param_name)
|
model_param_name = loaded_weight_name.replace(weight_name, param_name)
|
||||||
|
if model_param_name.startswith("model.") and self.fd_config.model_config.model_format == "torch":
|
||||||
|
model_param_name = model_param_name.replace("model.", "ernie.")
|
||||||
|
|
||||||
|
if model_param_name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[model_param_name]
|
param = params_dict[model_param_name]
|
||||||
expert_id = exp_id
|
expert_id = exp_id
|
||||||
shard_id = shard_id
|
shard_id = shard_id
|
||||||
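Two guards are added to the loading loop above: names coming from a torch-format checkpoint keep HuggingFace's `model.` prefix, so they are rewritten to the `ernie.` prefix used in FastDeploy's parameter dict, and any name that still has no matching parameter is skipped rather than raising a `KeyError`. A compact sketch of that normalization step (the dict below stands in for `named_parameters()`):

```python
# Sketch only: prefix rewrite plus skip-unknown-parameter guard, as in the loop above.
def normalize_param_name(model_param_name: str, model_format: str, params_dict: dict):
    if model_param_name.startswith("model.") and model_format == "torch":
        model_param_name = model_param_name.replace("model.", "ernie.")
    if model_param_name not in params_dict:
        return None  # caller continues to the next checkpoint tensor
    return model_param_name


if __name__ == "__main__":
    params = {"ernie.layers.0.self_attn.qkv_proj.weight": object()}
    print(normalize_param_name("model.layers.0.self_attn.qkv_proj.weight", "torch", params))
    print(normalize_param_name("model.unused.bias", "torch", params))  # None -> skipped
```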
@@ -657,7 +678,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
                 if "expert_id" in sig.parameters:
                     weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
                 else:
-                    weight_loader(param, loaded_weight)
+                    weight_loader(param, loaded_weight, shard_id)
                 model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
                 process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
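The `else` branch now forwards `shard_id` as well. As the context lines show, the caller inspects each loader's signature: loaders that accept `expert_id` (the fused MoE experts) get both ids, plain loaders get the tensor plus the shard id. A self-contained sketch of that signature-based dispatch using `inspect.signature` (both loaders here are dummies):

```python
# Sketch only: choose the call form based on whether the loader accepts expert_id,
# mirroring the `"expert_id" in sig.parameters` check in the hunk above.
import inspect


def moe_expert_loader(param, loaded_weight, expert_id=None, shard_id=None):
    print("moe loader got", expert_id, shard_id)


def fused_qkv_loader(param, loaded_weight, shard_id=None):
    print("fused qkv loader got", shard_id)


def dispatch(weight_loader, param, loaded_weight, expert_id, shard_id):
    sig = inspect.signature(weight_loader)
    if "expert_id" in sig.parameters:
        weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
    else:
        weight_loader(param, loaded_weight, shard_id)


if __name__ == "__main__":
    dispatch(moe_expert_loader, None, None, expert_id=7, shard_id="gate")
    dispatch(fused_qkv_loader, None, None, expert_id=None, shard_id="q")
```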
@@ -181,6 +181,8 @@ class VariableResolutionResamplerModel(nn.Layer):
             nn.Linear(self.spatial_dim, self.spatial_dim),
             nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
         )
+        set_weight_attrs(self.spatial_linear[0].weight, {"model_format": config.model_format})
+        set_weight_attrs(self.spatial_linear[2].weight, {"model_format": config.model_format})

         if self.use_temporal_conv:
             self.temporal_linear = nn.Sequential(
@@ -189,9 +191,13 @@ class VariableResolutionResamplerModel(nn.Layer):
                 nn.Linear(self.spatial_dim, self.spatial_dim),
                 nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
             )
+            set_weight_attrs(self.temporal_linear[0].weight, {"model_format": config.model_format})
+            set_weight_attrs(self.temporal_linear[2].weight, {"model_format": config.model_format})

         self.mlp = nn.Linear(self.spatial_dim, self.out_dim)

+        set_weight_attrs(self.mlp.weight, {"model_format": config.model_format})
+
         out_config = deepcopy(config)
         out_config.hidden_size = out_dim
         self.after_norm = RMSNorm(out_config)
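Across the diff, `set_weight_attrs` is the mechanism that makes the torch-format handling work for plain `nn.Linear`/conv weights: it attaches a `model_format` attribute to the parameter object, and the shared weight loaders read it back with `getattr` to decide whether to transpose. A simplified stand-in (not FastDeploy's actual `set_weight_attrs`) showing the idea end to end:

```python
# Sketch only: attribute tagging on parameter objects driving format-aware loading.
import numpy as np


class FakeParam:
    """Minimal parameter stand-in that accepts arbitrary attributes, like a paddle Parameter."""

    def __init__(self, value: np.ndarray):
        self.value = value


def set_weight_attrs(param, attrs: dict) -> None:
    # Attach loader hints (model_format, output_dim, weight_loader, ...) to the parameter.
    for key, val in attrs.items():
        setattr(param, key, val)


def default_weight_loader(param: FakeParam, loaded: np.ndarray) -> None:
    if getattr(param, "model_format", "") == "torch":
        loaded = loaded.transpose((1, 0))  # torch layout -> paddle layout
    param.value = loaded


if __name__ == "__main__":
    w = FakeParam(np.zeros((8, 4)))
    set_weight_attrs(w, {"model_format": "torch"})
    default_weight_loader(w, np.ones((4, 8)))  # checkpoint tensor stored in torch layout
    print(w.value.shape)  # (8, 4)
```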