Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[V1 Loader] Ernie VL support loader v1 (#3494)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* ernie vl support new loader
* add unittest
* fix test
@@ -191,7 +191,7 @@ class FusedMoE(nn.Layer):
loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
else:
expert_param = param[expert_id]
expert_param = param[expert_id - self.expert_id_offset]
loaded_weight = get_tensor(loaded_weight)
expert_param.copy_(loaded_weight, False)
else:
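The key change in both loader paths above is indexing the stacked expert parameter with `expert_id - self.expert_id_offset`: the mapping tables now carry global expert ids, while each FusedMoE layer only stores its own slice of experts. A minimal sketch of that translation (sizes, shapes, and the helper name below are hypothetical, not FastDeploy code):

```python
import numpy as np

# Hypothetical layout: 64 text experts (global ids 0..63) followed by
# 64 image experts (global ids 64..127); this layer owns the image slice.
expert_id_offset = 64
num_local_experts = 64
hidden, intermediate = 8, 16

# One stacked weight per *local* expert.
param = np.zeros((num_local_experts, hidden, intermediate), dtype=np.float32)

def load_expert_weight(param, loaded_weight, expert_id):
    # Translate the global (checkpoint) expert id into this layer's local slot.
    local_id = expert_id - expert_id_offset
    param[local_id] = loaded_weight

# The checkpoint tensor for global expert 70 lands in local slot 6.
load_expert_weight(param, np.ones((hidden, intermediate), dtype=np.float32), expert_id=70)
assert param[70 - expert_id_offset].sum() == hidden * intermediate
```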
@@ -262,7 +262,7 @@ class FusedMoE(nn.Layer):
loaded_weight,
shard_id,
):
expert_param = param[expert_id]
expert_param = param[expert_id - self.expert_id_offset]
if shard_id == "down":
self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
elif shard_id in ["gate", "up"]:
@@ -279,6 +279,7 @@ class FusedMoE(nn.Layer):
param_gate_up_proj_name: Optional[str] = None,
param_down_proj_name: Optional[str] = None,
ckpt_expert_key_name: str = "experts",
experts_offset: int = 0,
) -> list[tuple[str, str, int, str]]:
param_name_maping = []
@@ -303,7 +304,7 @@ class FusedMoE(nn.Layer):
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for expert_id in range(experts_offset, experts_offset + num_experts)
for shard_id, weight_name in param_name_maping
]
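With the new `experts_offset` argument, `make_expert_params_mapping` enumerates global expert ids starting at the offset, which lets the image experts be mapped directly after the text experts of a fused checkpoint. A simplified stand-in that mirrors the `(param_name, weight_name, expert_id, shard_id)` tuple layout noted in the diff; the checkpoint key format and shard ids below are assumptions:

```python
def make_expert_params_mapping_sketch(
    num_experts: int,
    ckpt_gate_up_proj_name: str,
    ckpt_down_proj_name: str,
    param_gate_up_proj_name: str,
    param_down_proj_name: str,
    ckpt_expert_key_name: str = "experts",
    experts_offset: int = 0,
) -> list:
    # (shard_id, checkpoint weight name, target param prefix)
    param_name_mapping = [
        ("gate_up", ckpt_gate_up_proj_name, param_gate_up_proj_name),
        ("down", ckpt_down_proj_name, param_down_proj_name),
    ]
    return [
        (
            param_name,
            f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.weight",  # assumed key format
            expert_id,
            shard_id,
        )
        for expert_id in range(experts_offset, experts_offset + num_experts)
        for shard_id, weight_name, param_name in param_name_mapping
    ]

# Image experts start right after the text experts of the fused checkpoint.
mapping = make_expert_params_mapping_sketch(
    num_experts=4,
    ckpt_gate_up_proj_name="up_gate_proj",
    ckpt_down_proj_name="down_proj",
    param_gate_up_proj_name="image_fused_moe.experts.up_gate_proj_",
    param_down_proj_name="image_fused_moe.experts.down_proj_",
    experts_offset=64,
)
assert mapping[0][2] == 64 and mapping[-1][2] == 67  # global expert ids 64..67
```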
@@ -15,6 +15,7 @@
"""

from functools import partial
from typing import Optional

import numpy as np
import paddle
@@ -32,7 +33,8 @@ from paddle.nn.functional.flash_attention import (
)
from paddleformers.transformers.model_utils import PretrainedModel

from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.layers.utils import divide, get_tensor
from fastdeploy.model_executor.models.utils import set_weight_attrs

from .activation import ACT2FN
from .configuration import DFNRopeVisionTransformerConfig
@@ -153,11 +155,13 @@ class VisionFlashAttention2(nn.Layer):
nn (_type_): _description_
"""

def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None:
def __init__(
self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1, tensor_parallel_rank: int = 0
) -> None:
super().__init__()
self.num_heads = num_heads
self.tensor_parallel_degree = tensor_parallel_degree

self.tensor_parallel_rank = tensor_parallel_rank
if tensor_parallel_degree > 1:
self.qkv = ColumnParallelLinear(
dim,
@@ -175,11 +179,42 @@ class VisionFlashAttention2(nn.Layer):
input_is_parallel=True,
has_bias=True,
)
set_weight_attrs(self.qkv.weight, {"weight_loader": self.weight_loader})
set_weight_attrs(self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True})
set_weight_attrs(self.qkv.bias, {"output_dim": True})
set_weight_attrs(self.proj.weight, {"output_dim": False})
else:
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
self.proj = nn.Linear(dim, dim)

self.head_dim = dim // num_heads  # must added
self.num_heads = num_heads
self.hidden_size = dim
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
load_bias = getattr(param, "load_bias", None)
if load_bias:
head_dim = self.hidden_size // self.num_heads
shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
shard_weight = shard_weight.reshape([-1])
else:
shard_weight = loaded_weight[...].reshape(
[
self.hidden_size,
3,
self.num_heads,
self.head_dim,
]
)
shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
shard_weight = shard_weight.reshape([self.hidden_size, -1])
shard_weight = get_tensor(shard_weight)
assert param.shape == shard_weight.shape, (
f" Attempted to load weight ({shard_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(shard_weight, False)

def forward(
self,
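The new `weight_loader` slices the fused qkv weight (and bias) for the current tensor-parallel rank along the head axis via a reshape, `np.split`, reshape pattern. A small standalone numpy check of that pattern, with made-up sizes rather than real model dimensions:

```python
import numpy as np

hidden, num_heads, tp_degree = 32, 8, 2
head_dim = hidden // num_heads
heads_per_rank = num_heads // tp_degree

# Fused qkv weight as stored in the checkpoint: [hidden, 3 * hidden]
qkv_weight = np.random.rand(hidden, 3 * hidden).astype("float32")

rank_shards = []
for rank in range(tp_degree):
    w = qkv_weight.reshape([hidden, 3, num_heads, head_dim])
    w = np.split(w, tp_degree, axis=-2)[rank]      # keep only this rank's heads
    rank_shards.append(w.reshape([hidden, -1]))    # back to 2-D: [hidden, 3 * heads_per_rank * head_dim]

assert rank_shards[0].shape == (hidden, 3 * heads_per_rank * head_dim)

# Re-concatenating the per-rank shards head-wise reproduces the original tensor.
restored = np.concatenate(
    [s.reshape([hidden, 3, heads_per_rank, head_dim]) for s in rank_shards], axis=-2
).reshape([hidden, -1])
assert np.array_equal(restored, qkv_weight)
```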
@@ -211,7 +246,6 @@ class VisionFlashAttention2(nn.Layer):
.transpose(perm=[1, 0, 2, 3])
)
q, k, v = qkv.unbind(axis=0)

q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
@@ -233,7 +267,6 @@ class VisionFlashAttention2(nn.Layer):
.squeeze(0)
.reshape([seq_length, -1])
)

attn_output = attn_output.astype(paddle.float32)
attn_output = self.proj(attn_output)
return attn_output
@@ -306,6 +339,9 @@ class VisionMlp(nn.Layer):
input_is_parallel=True,
has_bias=True,
)
set_weight_attrs(self.fc1.weight, {"output_dim": True})
set_weight_attrs(self.fc1.bias, {"output_dim": True})
set_weight_attrs(self.fc2.weight, {"output_dim": False})
else:
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
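Throughout the diff, `set_weight_attrs` attaches loader metadata such as `weight_loader`, `output_dim`, and `load_bias` to parameters, and the loaders later read it back with `getattr` (for example `getattr(param, "load_bias", None)` and `getattr(param, "weight_loader", ...)`). A minimal stand-in illustrating that contract, not the FastDeploy implementation:

```python
class FakeParam:
    """Stand-in for a paddle Parameter object."""

def set_weight_attrs_sketch(param, attrs: dict) -> None:
    # Attach loader metadata to the parameter so a generic loading loop can
    # later pick it up with getattr(param, "weight_loader", ...), etc.
    for key, value in attrs.items():
        setattr(param, key, value)

p = FakeParam()
set_weight_attrs_sketch(p, {"output_dim": False, "load_bias": True})
assert getattr(p, "load_bias", None) is True
assert getattr(p, "output_dim", None) is False
```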
@@ -365,6 +401,7 @@ class DFNRopeVisionBlock(nn.Layer):
self,
config,
tensor_parallel_degree: int,
tensor_parallel_rank: int,
attn_implementation: str = "sdpa",
) -> None:
"""_summary_
@@ -382,6 +419,7 @@ class DFNRopeVisionBlock(nn.Layer):
config.embed_dim,
num_heads=config.num_heads,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
)
self.mlp = VisionMlp(
dim=config.embed_dim,
@@ -407,7 +445,9 @@ class DFNRopeVisionBlock(nn.Layer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
)

hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))

return hidden_states
@@ -478,6 +518,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
DFNRopeVisionBlock(
config.vision_config,
config.pretrained_config.tensor_parallel_degree,
config.pretrained_config.tensor_parallel_rank,
)
for _ in range(config.vision_config.depth)
]
@@ -16,6 +16,7 @@

from __future__ import annotations

import inspect
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union
@@ -562,6 +563,93 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
def name(self):
return "Ernie4_5_VLMoeForConditionalGeneration"

def gate_correction_bias_loader(self, params_dict, loaded_weight_name, loaded_weight):
text_param_name = loaded_weight_name.replace(
"moe_statics.e_score_correction_bias", "text_fused_moe.experts.gate_correction_bias"
)
image_param_name = loaded_weight_name.replace(
"moe_statics.e_score_correction_bias", "image_fused_moe.experts.gate_correction_bias"
)
text_param = params_dict[text_param_name]
image_param = params_dict[image_param_name]
loaded_weight = get_tensor(loaded_weight)
text_param.copy_(loaded_weight[0].unsqueeze(0), False)
image_param.copy_(loaded_weight[1].unsqueeze(0), False)

@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.

Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""

from fastdeploy.model_executor.models.utils import default_weight_loader

general_params_mapping = [
# (param_name, weight_name, expert_id, shard_id)
("embed_tokens.embeddings", "embed_tokens", None, None),
("lm_head.linear", "lm_head", None, None),
("mlp.image_fused_moe.gate.weight", "mlp.gate.weight_1", None, "gate"),
("mlp.text_fused_moe.gate.weight", "mlp.gate.weight", None, "gate"),
("resampler_model", "ernie.resampler_model", None, None),
]

text_expert_params_mapping = []
if getattr(self.fd_config.model_config, "moe_num_experts", None) is not None:
text_expert_params_mapping = FusedMoE.make_expert_params_mapping(
num_experts=self.fd_config.model_config.moe_num_experts[0],
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_name="up_gate_proj",
param_gate_up_proj_name="text_fused_moe.experts.up_gate_proj_",
param_down_proj_name="text_fused_moe.experts.down_proj_",
)
image_expert_params_mapping = FusedMoE.make_expert_params_mapping(
num_experts=self.fd_config.model_config.moe_num_experts[1],
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_name="up_gate_proj",
param_gate_up_proj_name="image_fused_moe.experts.up_gate_proj_",
param_down_proj_name="image_fused_moe.experts.down_proj_",
experts_offset=self.fd_config.model_config.moe_num_experts[0],
)

all_param_mapping = general_params_mapping + text_expert_params_mapping + image_expert_params_mapping

params_dict = dict(self.named_parameters())
expert_id = None
shard_id = None
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
param = params_dict[model_param_name]
expert_id = exp_id
shard_id = shard_id
break
else:
# text and image gate_correction_bias is fused in ckpt and need load independently
if "moe_statics.e_score_correction_bias" in loaded_weight_name:
self.gate_correction_bias_loader(params_dict, loaded_weight_name, loaded_weight)
continue
if loaded_weight_name not in params_dict.keys():
continue
model_param_name = loaded_weight_name
param = params_dict[model_param_name]

# Get weight loader from parameter and set weight
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
sig = inspect.signature(weight_loader)

if "expert_id" in sig.parameters:
weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
else:
weight_loader(param, loaded_weight)

if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))

@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
"""
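`gate_correction_bias_loader` handles the one checkpoint tensor that feeds two parameters: `moe_statics.e_score_correction_bias` stores the text-expert row and the image-expert row together, and rows 0 and 1 are copied into the text and image `gate_correction_bias` parameters respectively. A tiny numpy illustration of that split; the exact shape is an assumption inferred from the indexing above:

```python
import numpy as np

num_experts_per_modality = 64
# Fused bias from the checkpoint: row 0 = text experts, row 1 = image experts.
fused_bias = np.random.rand(2, num_experts_per_modality).astype("float32")

text_bias = fused_bias[0][None, :]   # analogous to loaded_weight[0].unsqueeze(0)
image_bias = fused_bias[1][None, :]  # analogous to loaded_weight[1].unsqueeze(0)

assert text_bias.shape == image_bias.shape == (1, num_experts_per_modality)
```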
@@ -715,7 +803,6 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
"""
get_tensor_parallel_mappings
"""
logger.info("erine inference model _get_tensor_parallel_mappings")
from fastdeploy.model_executor.models.tp_utils import (
build_expanded_keys,
has_prefix,
@@ -30,6 +30,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import (
reduce_scatter_group,
scatter_axis,
)
from fastdeploy.model_executor.models.utils import set_weight_attrs


class ScatterOp(PyLayer):
@@ -201,7 +202,6 @@ class VariableResolutionResamplerModel(nn.Layer):
mark_as_sequence_parallel_parameter(self.spatial_linear[idx].bias)
_set_var_distributed(self.spatial_linear[idx].weight, split_axis=0)
_set_var_distributed(self.spatial_linear[idx].bias, split_axis=0)

if self.use_temporal_conv:
for idx in [0, 2, 3]:
mark_as_sequence_parallel_parameter(self.temporal_linear[idx].weight)
@@ -210,6 +210,7 @@ class VariableResolutionResamplerModel(nn.Layer):
mark_as_sequence_parallel_parameter(self.mlp.weight)
mark_as_sequence_parallel_parameter(self.mlp.bias)
mark_as_sequence_parallel_parameter(self.after_norm.weight)
set_weight_attrs(self.spatial_linear[0].weight, {"output_dim": False})

def spatial_conv_reshape(self, x, spatial_conv_size):
"""