fix and refine vl (#2866)

* refine vl config

* delete attn_sep

* fix vl accuracy
Yuanle Liu
2025-07-16 20:59:28 +08:00
committed by GitHub
parent aa76085d1f
commit 63d6e7ce06
11 changed files with 63 additions and 117 deletions

View File

@@ -71,6 +71,7 @@ class Ernie4_5_MLP(nn.Layer):
input_size=intermediate_size,
output_size=fd_config.model_config.hidden_size,
with_bias=False,
reduce_results=reduce_results,
)
self.act_fn = SiluAndMul(
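The new `reduce_results` flag controls whether this row-parallel down projection all-reduces its partial sums immediately or leaves them for the caller. A minimal numpy sketch of the idea, with `allreduce_sum` as a hypothetical stand-in for the tensor-parallel collective:

```python
import numpy as np

def allreduce_sum(partials):
    # Hypothetical stand-in for the tensor-parallel all-reduce: every
    # rank ends up with the sum of all ranks' partial results.
    total = np.sum(partials, axis=0)
    return [total for _ in partials]

def row_parallel_linear(x_shards, w_shards, reduce_results=True):
    # Each rank multiplies its input slice by its weight slice, giving
    # a partial sum of the full matmul.
    partials = [x @ w for x, w in zip(x_shards, w_shards)]
    if reduce_results:
        return allreduce_sum(partials)   # eager: one all-reduce per layer
    return partials                      # deferred: caller reduces later

# 2-way tensor parallelism: split activations and weights on the inner dim.
x, w = np.random.rand(4, 8), np.random.rand(8, 6)
out = row_parallel_linear(np.split(x, 2, axis=1), np.split(w, 2, axis=0))[0]
assert np.allclose(out, x @ w)
```

With `reduce_results=False`, the MoE code further down can add partial outputs from several branches first and issue a single all-reduce.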

View File

@@ -46,7 +46,6 @@ class DFNRopeVisionTransformerConfig(PretrainedConfig):
attn_implementation="eager", # new added
pp_data_balance=False,
recompute=False,
attn_sep=False,
vit_first_fwd_bsz=128,
vit_num_recompute_layers=10000,
**kwargs,
@@ -65,6 +64,5 @@ class DFNRopeVisionTransformerConfig(PretrainedConfig):
self.attn_implementation = attn_implementation
self.pp_data_balance = pp_data_balance
self.recompute = recompute
self.attn_sep = attn_sep
self.vit_first_fwd_bsz = vit_first_fwd_bsz
self.vit_num_recompute_layers = vit_num_recompute_layers

View File

@@ -143,30 +143,6 @@ def apply_rotary_pos_emb_vision(tensor: paddle.Tensor,
return output
def qkv_reshard_head(tensor, group):
"""
Concatenate q/k/v along the seq dimension, then convert their shard dimension in one collective (seq-sharded -> head-sharded).
"""
parallelism = group.nranks
qkv_seqlen, head_num, head_dim = tensor.shape
tensor = tensor.transpose(perm=[1, 0, 2]).contiguous()
out = _AllToAll.apply(tensor, group)
out = paddle.split(out, parallelism, axis=0)
output_q = []
output_k = []
output_v = []
for output_i in out:
output_t = output_i.transpose(perm=[1, 0, 2]).contiguous()
output = paddle.split(output_t, 3, axis=0)
output_q.append(output[0])
output_k.append(output[1])
output_v.append(output[2])
q = paddle.concat(output_q, axis=0)
k = paddle.concat(output_k, axis=0)
v = paddle.concat(output_v, axis=0)
return q, k, v
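The deleted helper turned q/k/v from sequence-sharded to head-sharded layout with a single all-to-all. A shape-level sketch of the same redistribution, simulating the collective with plain lists (no distributed runtime; rank count and tensors are illustrative):

```python
import numpy as np

def all_to_all(per_rank_chunks):
    # Simulated _AllToAll: rank r sends chunk j to rank j and receives
    # chunk r from every rank (per_rank_chunks[r][j] -> result[j][r]).
    nranks = len(per_rank_chunks)
    return [[per_rank_chunks[r][j] for r in range(nranks)]
            for j in range(nranks)]

def reshard_seq_to_head(qkv_per_rank):
    # qkv_per_rank[r]: [3 * local_seq, head_num, head_dim] -- q, k, v
    # stacked along the (sequence-sharded) first axis on rank r.
    nranks = len(qkv_per_rank)
    sends = [np.split(t, nranks, axis=1) for t in qkv_per_rank]  # split heads
    recvs = all_to_all(sends)
    out = []
    for chunks in recvs:  # this rank now sees all seq, one head slice
        qs, ks, vs = [], [], []
        for c in chunks:
            q_c, k_c, v_c = np.split(c, 3, axis=0)  # un-stack q/k/v
            qs.append(q_c); ks.append(k_c); vs.append(v_c)
        out.append((np.concatenate(qs), np.concatenate(ks),
                    np.concatenate(vs)))
    return out

# 2 ranks, local seq 2, 4 heads: after resharding, each rank holds the
# full sequence but only head_num / nranks heads.
qkv = [np.arange(3 * 2 * 4).reshape(6, 4, 1) + 100 * r for r in range(2)]
q0, k0, v0 = reshard_seq_to_head(qkv)[0]
assert q0.shape == (4, 2, 1)
```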
class VisionFlashAttention2(nn.Layer):
"""_summary_
@@ -211,7 +187,6 @@ class VisionFlashAttention2(nn.Layer):
hidden_states: paddle.Tensor,
cu_seqlens: paddle.Tensor,
rotary_pos_emb: paddle.Tensor = None,
attn_sep=False,
) -> paddle.Tensor:
"""_summary_
@@ -229,13 +204,6 @@ class VisionFlashAttention2(nn.Layer):
-1]).transpose(perm=[1, 0, 2, 3])
q, k, v = qkv.unbind(axis=0)
if attn_sep:
hcg = get_hcg()
mp_group = hcg.get_model_parallel_group()
qkv = paddle.concat([q, k, v], axis=0)
q, k, v = qkv_reshard_head(qkv, mp_group)
seq_length = q.shape[0]
q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0),
rotary_pos_emb).squeeze(axis=0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0),
@@ -256,10 +224,7 @@ class VisionFlashAttention2(nn.Layer):
max_seqlen,
scale=softmax_scale, # TODO: has to be passed in manually here
)[0].squeeze(0).reshape([seq_length, -1]))
if attn_sep:
out = _AllToAll.apply(attn_output, mp_group)
out = paddle.split(out, mp_group.nranks, axis=0)
attn_output = paddle.concat(out, axis=1)
attn_output = attn_output.astype(paddle.float32)
attn_output = self.proj(attn_output)
return attn_output
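The varlen attention call above relies on `cu_seqlens`/`max_seqlen` bookkeeping (see the `F.pad(cu_seqlens, (1, 0), value=0)` further down). A sketch of how those values are typically derived; the sequence lengths and head size here are illustrative:

```python
import numpy as np

# Packed ("unpadded") attention concatenates all images' patch sequences
# into one tensor; cu_seqlens marks the boundaries.
seq_lens = np.array([196, 49, 784], dtype=np.int32)   # patches per image
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)]).astype(np.int32)
# -> [0, 196, 245, 1029]; image i occupies rows cu_seqlens[i]:cu_seqlens[i+1]
max_seqlen = int(seq_lens.max())        # longest packed sequence
head_dim = 80                           # illustrative
softmax_scale = head_dim ** -0.5        # the scale passed manually above
assert cu_seqlens[-1] == seq_lens.sum()
```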
@@ -389,7 +354,7 @@ class DFNRopeVisionBlock(nn.Layer):
nn (_type_): _description_
"""
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
def __init__(self, config, tensor_parallel_degree: int, attn_implementation: str = "sdpa") -> None:
"""_summary_
Args:
@@ -404,19 +369,18 @@ class DFNRopeVisionBlock(nn.Layer):
self.attn = VisionFlashAttention2(
config.embed_dim,
num_heads=config.num_heads,
tensor_parallel_degree=config.tensor_parallel_degree)
tensor_parallel_degree=tensor_parallel_degree)
self.mlp = VisionMlp(
dim=config.embed_dim,
hidden_dim=mlp_hidden_dim,
hidden_act=config.hidden_act,
tensor_parallel_degree=config.tensor_parallel_degree)
tensor_parallel_degree=tensor_parallel_degree)
self.config = config
def forward(self,
hidden_states,
cu_seqlens,
rotary_pos_emb,
attn_sep=False) -> paddle.Tensor:
rotary_pos_emb) -> paddle.Tensor:
"""_summary_
Args:
@@ -431,7 +395,6 @@ class DFNRopeVisionBlock(nn.Layer):
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
attn_sep=attn_sep,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -490,26 +453,26 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
config_class = DFNRopeVisionTransformerConfig
def __init__(self, config, prefix_name: str = "") -> None:
super().__init__(config)
self.spatial_merge_size = config.spatial_merge_size
super().__init__(config.vision_config)
self.spatial_merge_size = config.vision_config.spatial_merge_size
self.prefix_name = prefix_name
self.patch_embed = PatchEmbed(
patch_size=config.patch_size,
in_channels=config.in_channels,
embed_dim=config.embed_dim,
patch_size=config.vision_config.patch_size,
in_channels=config.vision_config.in_channels,
embed_dim=config.vision_config.embed_dim,
)
head_dim = config.embed_dim // config.num_heads
head_dim = config.vision_config.embed_dim // config.vision_config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.LayerList(
[DFNRopeVisionBlock(config) for _ in range(config.depth)])
[DFNRopeVisionBlock(config.vision_config, config.pretrained_config.tensor_parallel_degree) for _ in range(config.vision_config.depth)])
assert (
config.hidden_size == config.embed_dim
config.vision_config.hidden_size == config.vision_config.embed_dim
), "in DFNRope, vit's config.hidden must be equal to config.embed_dim"
# self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
self.ln = nn.LayerNorm(config.hidden_size, epsilon=1e-6)
self.ln = nn.LayerNorm(config.vision_config.hidden_size, epsilon=1e-6)
def get_dtype(self) -> paddle.dtype:
"""_summary_
@@ -593,7 +556,6 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
else:
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
attn_sep = getattr(self.config, "attn_sep", False)
vit_num_recompute_layers = getattr(self.config,
"vit_num_recompute_layers",
self.config.depth)
@@ -601,13 +563,12 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
for idx, blk in enumerate(self.blocks):
if self.config.recompute and self.training and idx < vit_num_recompute_layers:
hidden_states = recompute(blk, hidden_states, cu_seqlens,
rotary_pos_emb, attn_sep)
rotary_pos_emb)
else:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
attn_sep=attn_sep,
)
# ret = self.merger(hidden_states)
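With `attn_sep` gone, the checkpointed and plain paths take identical arguments. A sketch of the dispatch pattern, assuming `recompute` is `paddle.distributed.fleet.utils.recompute` (gradient checkpointing: block activations are freed in forward and recomputed in backward):

```python
from paddle.distributed.fleet.utils import recompute

def run_blocks(blocks, hidden_states, cu_seqlens, rotary_pos_emb,
               training, use_recompute, num_recompute_layers):
    for idx, blk in enumerate(blocks):
        if use_recompute and training and idx < num_recompute_layers:
            # Checkpointed path: trades extra compute for lower peak memory.
            hidden_states = recompute(blk, hidden_states, cu_seqlens,
                                      rotary_pos_emb)
        else:
            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens,
                                rotary_pos_emb=rotary_pos_emb)
    return hidden_states
```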

View File

@@ -123,6 +123,7 @@ class Ernie4_5_VLMoE(nn.Layer):
fd_config=fd_config,
intermediate_size=fd_config.model_config.intermediate_size,
prefix=f"{prefix}",
reduce_results=False,
)
assert image_moe_layer_start_index <= image_moe_layer_end_index
@@ -155,6 +156,7 @@ class Ernie4_5_VLMoE(nn.Layer):
fd_config=fd_config,
intermediate_size=fd_config.model_config.intermediate_size,
prefix=f"{prefix}",
reduce_results=False,
)
self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
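Setting `reduce_results=False` on the shared-experts MLP lets the layer add the shared-expert and routed-expert partial outputs locally and all-reduce once, since all-reduce commutes with addition. A numpy sketch with simulated ranks:

```python
import numpy as np

def moe_layer_per_rank(partial_routed, partial_shared):
    # reduce_results=False: both inputs are per-rank partial sums.
    # Adding first and reducing once is equivalent to reducing each and
    # then adding, but costs one all-reduce instead of two.
    return partial_routed + partial_shared

routed = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]   # 2 TP ranks
shared = [np.array([0.5, 0.5]), np.array([1.5, 1.5])]
fused = sum(moe_layer_per_rank(r, s) for r, s in zip(routed, shared))
two_reduces = sum(routed) + sum(shared)
assert np.allclose(fused, two_reduces)
```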
@@ -471,8 +473,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
"""
super(Ernie4_5_VLMoeForConditionalGeneration, self).__init__(fd_config)
# ----------- vision model ------------
vision_config = fd_config.model_config.vision_config
self.vision_model = self._init_vision_model(vision_config)
self.vision_model = self._init_vision_model(fd_config.model_config)
# ----------- resampler_model ------------
self.resampler_model = self._init_resampler_model_model(
fd_config.model_config
@@ -490,12 +491,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
)
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
def _init_vision_model(self, vision_config) -> nn.Layer:
def _init_vision_model(self, model_config) -> nn.Layer:
from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope.modeling import \
DFNRopeVisionTransformerPretrainedModel
vision_model = DFNRopeVisionTransformerPretrainedModel(
vision_config, prefix_name="vision_model"
model_config, prefix_name="vision_model"
)
vision_model = paddle.amp.decorate(
models=vision_model, level="O2", dtype="bfloat16"
@@ -508,7 +509,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
VariableResolutionResamplerModel
resampler_model = VariableResolutionResamplerModel(
model_config.pixel_hidden_size,
model_config.vision_config.hidden_size,
model_config.hidden_size,
model_config.spatial_conv_size,
model_config.temporal_conv_size,
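The first resampler argument now comes from `vision_config.hidden_size` instead of the separate `pixel_hidden_size` field. A sketch of the dimension flow, assuming the usual 2x2 spatial merge (all sizes illustrative):

```python
vision_hidden = 1280      # model_config.vision_config.hidden_size
text_hidden = 8192        # model_config.hidden_size (illustrative)
spatial_conv_size, temporal_conv_size = 2, 2
# The 2x2 spatial merge concatenates four patch features per token, so
# the resampler ultimately projects spatial_dim -> text_hidden:
spatial_dim = vision_hidden * spatial_conv_size ** 2   # 5120
print(spatial_dim, "->", text_hidden)
```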

View File

@@ -104,7 +104,7 @@ class RMSNorm(nn.Layer):
self.variance_epsilon = config.rms_norm_eps
self.config = config
if config.sequence_parallel:
if getattr(config, "sequence_parallel", False):
mark_as_sequence_parallel_parameter(self.weight)
def forward(self, hidden_states):
@@ -118,7 +118,6 @@ class RMSNorm(nn.Layer):
Tensor: Normalized output tensor of same shape as input
Note:
- Uses fused kernel if config.fuse_rms_norm is True for better performance
- Otherwise computes RMSNorm manually:
1. Compute variance of features
2. Apply reciprocal square root normalization
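A numpy sketch of exactly those two steps in the unfused path (learned weight and float32 accumulation assumed, as is conventional for RMSNorm):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # 1. variance of features (mean of squares; no mean subtraction)
    variance = (x.astype(np.float32) ** 2).mean(axis=-1, keepdims=True)
    # 2. reciprocal square root normalization, then elementwise scale
    return (x * np.reciprocal(np.sqrt(variance + eps))) * weight

x = np.random.randn(4, 8).astype(np.float32)
w = np.ones(8, dtype=np.float32)
out = rms_norm(x, w)
assert out.shape == x.shape
```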
@@ -146,9 +145,9 @@ class VariableResolutionResamplerModel(nn.Layer):
self.config = config
self.spatial_conv_size = spatial_conv_size
self.temporal_conv_size = temporal_conv_size
self.use_recompute_resampler = config.use_recompute_resampler
self.use_temporal_conv = config.use_temporal_conv
self.tensor_parallel_degree = config.tensor_parallel_degree
self.use_recompute_resampler = False
self.use_temporal_conv = True
self.tensor_parallel_degree = config.pretrained_config.tensor_parallel_degree
self.prefix_name = prefix_name
# for spatial four-into-one (2x2 patch) merging
@@ -165,7 +164,7 @@ class VariableResolutionResamplerModel(nn.Layer):
input_is_parallel=True,
has_bias=True,
fuse_matmul_bias=True,
) if config.tensor_parallel_degree > 1 else nn.Linear(
) if self.tensor_parallel_degree > 1 else nn.Linear(
self.spatial_dim, self.spatial_dim)),
nn.GELU(),
nn.Linear(self.spatial_dim, self.spatial_dim),
@@ -184,11 +183,9 @@ class VariableResolutionResamplerModel(nn.Layer):
out_config = deepcopy(config)
out_config.hidden_size = out_dim
# Note(GuoxiaWang): fuse can reduce gpu peak memory
out_config.fuse_rms_norm = out_config.resampler_fuse_rms_norm
self.after_norm = RMSNorm(out_config)
if config.tensor_parallel_degree > 1:
if self.tensor_parallel_degree > 1:
for idx in [2, 3]:
mark_as_sequence_parallel_parameter(
self.spatial_linear[idx].weight)
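The "spatial four-into-one" step feeding `spatial_linear` merges each 2x2 neighborhood of patch features into one token by channel concatenation. A numpy sketch of that intent (grid sizes illustrative; the model may realize it as a reshape or strided conv):

```python
import numpy as np

def spatial_merge(patches, h, w, conv=2, dim=1280):
    # patches: [h*w, dim] in row-major grid order; group each conv x conv
    # neighborhood and concatenate its features along the channel axis.
    x = patches.reshape(h // conv, conv, w // conv, conv, dim)
    x = x.transpose(0, 2, 1, 3, 4)               # [H/2, W/2, 2, 2, dim]
    return x.reshape((h // conv) * (w // conv), conv * conv * dim)

tokens = spatial_merge(np.zeros((16 * 16, 1280), dtype=np.float32), 16, 16)
assert tokens.shape == (64, 4 * 1280)   # 4x fewer tokens, 4x wider features
```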