[V1 Loader] Ernie VL support loader v1 (#3494)

* ernie vl support new loader

* add unittest

* fix test
Author: YuanRisheng
Date: 2025-08-22 11:16:57 +08:00
Committed by: GitHub
Parent: 3cc182236a
Commit: 85fbf5455a
5 changed files with 367 additions and 10 deletions

View File

@@ -15,6 +15,7 @@
"""
from functools import partial
from typing import Optional
import numpy as np
import paddle
@@ -32,7 +33,8 @@ from paddle.nn.functional.flash_attention import (
)
from paddleformers.transformers.model_utils import PretrainedModel
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.layers.utils import divide, get_tensor
from fastdeploy.model_executor.models.utils import set_weight_attrs
from .activation import ACT2FN
from .configuration import DFNRopeVisionTransformerConfig
@@ -153,11 +155,13 @@ class VisionFlashAttention2(nn.Layer):
nn (_type_): _description_
"""
def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None:
def __init__(
self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1, tensor_parallel_rank: int = 0
) -> None:
super().__init__()
self.num_heads = num_heads
self.tensor_parallel_degree = tensor_parallel_degree
self.tensor_parallel_rank = tensor_parallel_rank
if tensor_parallel_degree > 1:
self.qkv = ColumnParallelLinear(
dim,
@@ -175,11 +179,42 @@ class VisionFlashAttention2(nn.Layer):
input_is_parallel=True,
has_bias=True,
)
set_weight_attrs(self.qkv.weight, {"weight_loader": self.weight_loader})
set_weight_attrs(self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True})
set_weight_attrs(self.qkv.bias, {"output_dim": True})
set_weight_attrs(self.proj.weight, {"output_dim": False})
else:
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
self.proj = nn.Linear(dim, dim)
self.head_dim = dim // num_heads  # must be added
self.num_heads = num_heads
self.hidden_size = dim
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
load_bias = getattr(param, "load_bias", None)
if load_bias:
head_dim = self.hidden_size // self.num_heads
shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
shard_weight = shard_weight.reshape([-1])
else:
shard_weight = loaded_weight[...].reshape(
[
self.hidden_size,
3,
self.num_heads,
self.head_dim,
]
)
shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
shard_weight = shard_weight.reshape([self.hidden_size, -1])
shard_weight = get_tensor(shard_weight)
assert param.shape == shard_weight.shape, (
f" Attempted to load weight ({shard_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(shard_weight, False)
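
The weight_loader above shards the fused QKV tensors head-wise across tensor-parallel ranks before copying them into the ColumnParallelLinear parameters. A minimal numpy sketch of that slicing, with illustrative shapes (not the FastDeploy API):

```python
import numpy as np

# Illustrative sizes only; the real values come from the vision config.
hidden_size, num_heads, head_dim = 32, 4, 8
tp_degree, tp_rank = 2, 1

qkv_weight = np.random.rand(hidden_size, 3 * num_heads * head_dim).astype("float32")
qkv_bias = np.random.rand(3 * num_heads * head_dim).astype("float32")

# Weight: view as [hidden, 3, heads, head_dim], keep this rank's heads, flatten back.
w = qkv_weight.reshape([hidden_size, 3, num_heads, head_dim])
w_shard = np.split(w, tp_degree, axis=-2)[tp_rank].reshape([hidden_size, -1])

# Bias (load_bias=True branch): view as [3, heads, head_dim], same head-wise split.
b = qkv_bias.reshape([3, num_heads, head_dim])
b_shard = np.split(b, tp_degree, axis=-2)[tp_rank].reshape([-1])

assert w_shard.shape == (hidden_size, 3 * (num_heads // tp_degree) * head_dim)
assert b_shard.shape == (3 * (num_heads // tp_degree) * head_dim,)
```
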
def forward(
self,
@@ -211,7 +246,6 @@ class VisionFlashAttention2(nn.Layer):
.transpose(perm=[1, 0, 2, 3])
)
q, k, v = qkv.unbind(axis=0)
q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
@@ -233,7 +267,6 @@ class VisionFlashAttention2(nn.Layer):
.squeeze(0)
.reshape([seq_length, -1])
)
attn_output = attn_output.astype(paddle.float32)
attn_output = self.proj(attn_output)
return attn_output
@@ -306,6 +339,9 @@ class VisionMlp(nn.Layer):
input_is_parallel=True,
has_bias=True,
)
set_weight_attrs(self.fc1.weight, {"output_dim": True})
set_weight_attrs(self.fc1.bias, {"output_dim": True})
set_weight_attrs(self.fc2.weight, {"output_dim": False})
else:
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
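
The "output_dim" attribute attached with set_weight_attrs above is consumed by the generic weight loader to decide which axis of the full checkpoint tensor belongs to this rank. A sketch of that convention, assuming Paddle's [in_features, out_features] weight layout (the helper name is hypothetical):

```python
import numpy as np

def shard_by_output_dim(full_weight, output_dim, tp_degree, tp_rank):
    """Hypothetical helper: slice a checkpoint tensor along the parallelized axis.

    output_dim=True  -> column-parallel layer (fc1), split on the last axis.
    output_dim=False -> row-parallel layer (fc2/proj), split on the first axis.
    """
    axis = -1 if output_dim else 0
    return np.split(full_weight, tp_degree, axis=axis)[tp_rank]

# fc1 [dim, hidden_dim] and fc2 [hidden_dim, dim] both end up sharded on hidden_dim.
fc1_full = np.zeros([64, 256], dtype="float32")
fc2_full = np.zeros([256, 64], dtype="float32")
print(shard_by_output_dim(fc1_full, True, 4, 0).shape)   # (64, 64)
print(shard_by_output_dim(fc2_full, False, 4, 0).shape)  # (64, 64)
```
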
@@ -365,6 +401,7 @@ class DFNRopeVisionBlock(nn.Layer):
self,
config,
tensor_parallel_degree: int,
tensor_parallel_rank: int,
attn_implementation: str = "sdpa",
) -> None:
"""_summary_
@@ -382,6 +419,7 @@ class DFNRopeVisionBlock(nn.Layer):
config.embed_dim,
num_heads=config.num_heads,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
)
self.mlp = VisionMlp(
dim=config.embed_dim,
@@ -407,7 +445,9 @@ class DFNRopeVisionBlock(nn.Layer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -478,6 +518,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
DFNRopeVisionBlock(
config.vision_config,
config.pretrained_config.tensor_parallel_degree,
config.pretrained_config.tensor_parallel_rank,
)
for _ in range(config.vision_config.depth)
]
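
The num_heads_per_rank computed with divide() in VisionFlashAttention2 is exact integer division, which only works when num_heads is a multiple of the tensor-parallel degree. A minimal sketch of the assumed behavior of that helper:

```python
def divide(numerator: int, denominator: int) -> int:
    """Assumed behavior of layers.utils.divide: exact integer division with a divisibility check."""
    assert numerator % denominator == 0, f"{numerator} is not evenly divisible by {denominator}"
    return numerator // denominator

# e.g. 16 vision attention heads over a tensor-parallel degree of 4 -> 4 heads per rank
num_heads_per_rank = divide(16, 4)
```
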

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import inspect
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union
@@ -562,6 +563,93 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
def name(self):
return "Ernie4_5_VLMoeForConditionalGeneration"
def gate_correction_bias_loader(self, params_dict, loaded_weight_name, loaded_weight):
text_param_name = loaded_weight_name.replace(
"moe_statics.e_score_correction_bias", "text_fused_moe.experts.gate_correction_bias"
)
image_param_name = loaded_weight_name.replace(
"moe_statics.e_score_correction_bias", "image_fused_moe.experts.gate_correction_bias"
)
text_param = params_dict[text_param_name]
image_param = params_dict[image_param_name]
loaded_weight = get_tensor(loaded_weight)
text_param.copy_(loaded_weight[0].unsqueeze(0), False)
image_param.copy_(loaded_weight[1].unsqueeze(0), False)
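
gate_correction_bias_loader exists because the checkpoint stores the text and image gate correction biases fused in one tensor, with the text bias in row 0 and the image bias in row 1. A small Paddle sketch of that split (sizes are illustrative):

```python
import paddle

num_experts = 64  # illustrative; the real counts come from moe_num_experts
# Assumed fused layout: row 0 -> text experts, row 1 -> image experts.
fused_bias = paddle.rand([2, num_experts])

text_bias = fused_bias[0].unsqueeze(0)   # [1, num_experts] -> text_fused_moe.experts.gate_correction_bias
image_bias = fused_bias[1].unsqueeze(0)  # [1, num_experts] -> image_fused_moe.experts.gate_correction_bias
```
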
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
general_params_mapping = [
# (param_name, weight_name, expert_id, shard_id)
("embed_tokens.embeddings", "embed_tokens", None, None),
("lm_head.linear", "lm_head", None, None),
("mlp.image_fused_moe.gate.weight", "mlp.gate.weight_1", None, "gate"),
("mlp.text_fused_moe.gate.weight", "mlp.gate.weight", None, "gate"),
("resampler_model", "ernie.resampler_model", None, None),
]
text_expert_params_mapping = []
if getattr(self.fd_config.model_config, "moe_num_experts", None) is not None:
text_expert_params_mapping = FusedMoE.make_expert_params_mapping(
num_experts=self.fd_config.model_config.moe_num_experts[0],
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_name="up_gate_proj",
param_gate_up_proj_name="text_fused_moe.experts.up_gate_proj_",
param_down_proj_name="text_fused_moe.experts.down_proj_",
)
image_expert_params_mapping = FusedMoE.make_expert_params_mapping(
num_experts=self.fd_config.model_config.moe_num_experts[1],
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_name="up_gate_proj",
param_gate_up_proj_name="image_fused_moe.experts.up_gate_proj_",
param_down_proj_name="image_fused_moe.experts.down_proj_",
experts_offset=self.fd_config.model_config.moe_num_experts[0],
)
all_param_mapping = general_params_mapping + text_expert_params_mapping + image_expert_params_mapping
params_dict = dict(self.named_parameters())
expert_id = None
shard_id = None
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
param = params_dict[model_param_name]
expert_id = exp_id
shard_id = shard_id
break
else:
# text and image gate_correction_bias is fused in ckpt and need load independently
if "moe_statics.e_score_correction_bias" in loaded_weight_name:
self.gate_correction_bias_loader(params_dict, loaded_weight_name, loaded_weight)
continue
if loaded_weight_name not in params_dict.keys():
continue
model_param_name = loaded_weight_name
param = params_dict[model_param_name]
# Get weight loader from parameter and set weight
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
sig = inspect.signature(weight_loader)
if "expert_id" in sig.parameters:
weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
else:
weight_loader(param, loaded_weight)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
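
The load_weights loop is a mapping-driven dispatch: each checkpoint name is matched against the (param_name, weight_name, expert_id, shard_id) table, renamed to the model parameter name, and handed to that parameter's weight_loader, with expert/shard ids passed only when the loader's signature accepts them. A condensed, self-contained sketch of the pattern (the mapping entries, parameter names, and dummy loaders are illustrative, not the real FastDeploy structures):

```python
import inspect

def plain_loader(param, weight):
    param["value"] = weight

def expert_loader(param, weight, expert_id=None, shard_id=None):
    param.setdefault("experts", {})[(expert_id, shard_id)] = weight

params = {
    "ernie.embed_tokens.embeddings.weight": {"weight_loader": plain_loader},
    "ernie.layers.0.mlp.text_fused_moe.experts.up_gate_proj_weight": {"weight_loader": expert_loader},
}
mapping = [  # (param_name, weight_name, expert_id, shard_id)
    ("embed_tokens.embeddings", "embed_tokens", None, None),
    ("mlp.text_fused_moe.experts.up_gate_proj_weight", "mlp.experts.0.up_gate_proj.weight", 0, "gate_up"),
]
weights_iterator = [
    ("ernie.embed_tokens.weight", "W_embed"),
    ("ernie.layers.0.mlp.experts.0.up_gate_proj.weight", "W_expert0"),
]

for ckpt_name, weight in weights_iterator:
    for param_name, weight_name, expert_id, shard_id in mapping:
        if weight_name not in ckpt_name:
            continue
        target = params[ckpt_name.replace(weight_name, param_name)]
        loader = target["weight_loader"]
        # Only loaders that declare expert_id (e.g. fused MoE loaders) receive the extra ids.
        if "expert_id" in inspect.signature(loader).parameters:
            loader(target, weight, expert_id=expert_id, shard_id=shard_id)
        else:
            loader(target, weight)
        break
```
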
@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
"""
@@ -715,7 +803,6 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
"""
get_tensor_parallel_mappings
"""
logger.info("erine inference model _get_tensor_parallel_mappings")
from fastdeploy.model_executor.models.tp_utils import (
build_expanded_keys,
has_prefix,

View File

@@ -30,6 +30,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import (
reduce_scatter_group,
scatter_axis,
)
from fastdeploy.model_executor.models.utils import set_weight_attrs
class ScatterOp(PyLayer):
@@ -201,7 +202,6 @@ class VariableResolutionResamplerModel(nn.Layer):
mark_as_sequence_parallel_parameter(self.spatial_linear[idx].bias)
_set_var_distributed(self.spatial_linear[idx].weight, split_axis=0)
_set_var_distributed(self.spatial_linear[idx].bias, split_axis=0)
if self.use_temporal_conv:
for idx in [0, 2, 3]:
mark_as_sequence_parallel_parameter(self.temporal_linear[idx].weight)
@@ -210,6 +210,7 @@ class VariableResolutionResamplerModel(nn.Layer):
mark_as_sequence_parallel_parameter(self.mlp.weight)
mark_as_sequence_parallel_parameter(self.mlp.bias)
mark_as_sequence_parallel_parameter(self.after_norm.weight)
set_weight_attrs(self.spatial_linear[0].weight, {"output_dim": False})
def spatial_conv_reshape(self, x, spatial_conv_size):
"""