From 63d6e7ce060bb50a07abf591b8c39829f8087e20 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Wed, 16 Jul 2025 20:59:28 +0800 Subject: [PATCH] fix and refine vl (#2866) * refine vl config * delete attn_sep * fix vl accuracy --- fastdeploy/config.py | 47 +++++-------- .../model_executor/layers/embeddings.py | 1 - .../model_executor/layers/rotary_embedding.py | 16 ++--- .../model_executor/models/ernie4_5_moe.py | 1 + .../ernie4_5_vl/dfnrope/configuration.py | 2 - .../models/ernie4_5_vl/dfnrope/modeling.py | 69 ++++--------------- .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 11 +-- .../models/ernie4_5_vl/modeling_resampler.py | 15 ++-- fastdeploy/rl/rollout_config.py | 2 +- fastdeploy/worker/gpu_model_runner.py | 13 ++-- fastdeploy/worker/worker_process.py | 3 + 11 files changed, 63 insertions(+), 117 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 55bf59b8e..34c295ac3 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -21,7 +21,6 @@ from enum import Enum from typing import Literal, Optional from paddleformers.transformers.configuration_utils import PretrainedConfig -from paddleformers.trl import llm_utils from fastdeploy import envs from fastdeploy.model_executor.layers.quantization.quant_base import \ @@ -39,27 +38,23 @@ class MoEPhase(Enum): DECODER = 2 PRETRAINED_INIT_CONFIGURATION = { - "rope_theta": 10000.0, - "num_key_value_heads":-1, - "start_layer_index": 0, - "moe_num_shared_experts":0, - "moe_layer_start_index": 0, - "num_max_dispatch_tokens_per_rank":256, - "moe_use_aux_free":False, - "vocab_size": -1, + "rope_theta" : 10000.0, + "num_key_value_heads" : -1, + "start_layer_index" : 0, + "moe_num_shared_experts" : 0, + "moe_layer_start_index" : 0, + "num_max_dispatch_tokens_per_rank" : 256, + "moe_use_aux_free" : False, + "vocab_size" : -1, "use_rope": True, - "hidden_dropout_prob":0.0, - "initializer_range":0.02, - "max_position_embeddings":512, - "quantization_config":None, - "use_recompute_resampler":False, - "use_temporal_conv":True, - "resampler_fuse_rms_norm":False, - "freq_allocation":20, - "tie_word_embeddings":False, - "rms_norm_eps":1e-5, - "moe_num_experts": None, - "moe_layer_end_index":None, + "hidden_dropout_prob" : 0.0, + "initializer_range" : 0.02, + "max_position_embeddings" : 512, + "quantization_config" : None, + "tie_word_embeddings" : False, + "rms_norm_eps" : 1e-5, + "moe_num_experts" : None, + "moe_layer_end_index" : None, } @@ -84,9 +79,6 @@ class ModelConfig: self.min_length = 1 self.model_name_or_path = "" - self.im_patch_id = ( - 100295 # multimodality, TODO(liuyuanle): read from config.json - ) self.is_quantized = False self.max_model_len = 0 self.dtype = "" @@ -130,10 +122,9 @@ class ParallelConfig: self.moe_phase = MoEPhase.PREFILL # Generation phase self.msg_queue_id = 1 # mesage queue id - tensor_parallel_rank, tensor_parallel_size = llm_utils.init_dist_env() - self.tensor_parallel_rank = tensor_parallel_rank # TP rank ID - self.tensor_parallel_size = tensor_parallel_size # TP degree - self.expert_parallel_rank = int(tensor_parallel_rank / tensor_parallel_size) # EP rank ID + self.tensor_parallel_rank = 0 # TP rank ID + self.tensor_parallel_size = 1 # TP degree + self.expert_parallel_rank = 0 # EP rank ID self.expert_parallel_size = 1 # EP degree # The embedding weight distributed on your gpu cards is divided by row or column. # Defaults to False means divide by row. When vocab_size can not be divided by world_size diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index cc446f4bf..c7ad68ec1 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -62,7 +62,6 @@ class VocabParallelEmbedding(nn.Layer): self.use_ep: bool = fd_config.parallel_config.use_ep self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob self.initializer_range: float = fd_config.model_config.initializer_range - self.sequence_parallel: bool = fd_config.parallel_config.sequence_parallel self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings self.params_dtype: str = params_dtype diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index 3266d1097..18bb1be33 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -423,11 +423,11 @@ class ErnieVlRotaryEmbedding3D: def get_rope_3d( rotary_dim: int, - base: 10000, - position_ids, - paritial_rotary_factor: 1, - max_position: 131072, - freq_allocation: 2, + base: float, + position_ids: paddle.Tensor, + partial_rotary_factor: float, + max_position: int, + freq_allocation: int, ) -> paddle.Tensor: """ Pre-calculate rotary position embedding for position_ids. @@ -435,19 +435,19 @@ def get_rope_3d( Args: rotary_dim (int): Dimension of rotary embeddings (head dimension) - base (float, optional): + base (float): Base value used to compute the inverse frequencies. Default: 10000.0. position_ids (paddle.Tensor): Tensor containing position indices of input tokens. - partial_rotary_factor (int, optional): + partial_rotary_factor (float): Factor controlling partial rotary application. Default: 1 (apply to all dimensions). max_position: Maximum position index to precompute. freq_allocation: Number of rotary dimensions allocated to temporal axis """ rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base, - paritial_rotary_factor, + partial_rotary_factor, max_position, freq_allocation) rotary_emb_3d = rotary_emb3d_layer(position_ids) diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 3c8e0d8e5..98ec7090b 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -71,6 +71,7 @@ class Ernie4_5_MLP(nn.Layer): input_size=intermediate_size, output_size=fd_config.model_config.hidden_size, with_bias=False, + reduce_results=reduce_results, ) self.act_fn = SiluAndMul( diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/configuration.py b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/configuration.py index 243b857f4..74c8fbc9f 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/configuration.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/configuration.py @@ -46,7 +46,6 @@ class DFNRopeVisionTransformerConfig(PretrainedConfig): attn_implementation="eager", # new added pp_data_balance=False, recompute=False, - attn_sep=False, vit_first_fwd_bsz=128, vit_num_recompute_layers=10000, **kwargs, @@ -65,6 +64,5 @@ class DFNRopeVisionTransformerConfig(PretrainedConfig): self.attn_implementation = attn_implementation self.pp_data_balance = pp_data_balance self.recompute = recompute - self.attn_sep = attn_sep self.vit_first_fwd_bsz = vit_first_fwd_bsz self.vit_num_recompute_layers = vit_num_recompute_layers diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py index 8c70c146c..ff532dd4c 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py @@ -143,30 +143,6 @@ def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, return output -def qkv_reshard_head(tensor, group): - """ - 将qkv在seq维度拼接后一起做切分维度的转换 - """ - parallelism = group.nranks - qkv_seqlen, head_num, head_dim = tensor.shape - tensor = tensor.transpose(perm=[1, 0, 2]).contiguous() - out = _AllToAll.apply(tensor, group) - out = paddle.split(out, parallelism, axis=0) - output_q = [] - output_k = [] - output_v = [] - for output_i in out: - outout = output_i.transpose(perm=[1, 0, 2]).contiguous() - output = paddle.split(outout, 3, axis=0) - output_q.append(output[0]) - output_k.append(output[1]) - output_v.append(output[2]) - q = paddle.concat(output_q, axis=0) - k = paddle.concat(output_k, axis=0) - v = paddle.concat(output_v, axis=0) - return q, k, v - - class VisionFlashAttention2(nn.Layer): """_summary_ @@ -211,7 +187,6 @@ class VisionFlashAttention2(nn.Layer): hidden_states: paddle.Tensor, cu_seqlens: paddle.Tensor, rotary_pos_emb: paddle.Tensor = None, - attn_sep=False, ) -> paddle.Tensor: """_summary_ @@ -229,13 +204,6 @@ class VisionFlashAttention2(nn.Layer): -1]).transpose(perm=[1, 0, 2, 3]) q, k, v = qkv.unbind(axis=0) - if attn_sep: - hcg = get_hcg() - mp_group = hcg.get_model_parallel_group() - qkv = paddle.concat([q, k, v], axis=0) - q, k, v = qkv_reshard_head(qkv, mp_group) - seq_length = q.shape[0] - q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), @@ -256,10 +224,7 @@ class VisionFlashAttention2(nn.Layer): max_seqlen, scale=softmax_scale, # TODO: 需要手动加上 )[0].squeeze(0).reshape([seq_length, -1])) - if attn_sep: - out = _AllToAll.apply(attn_output, mp_group) - out = paddle.split(out, mp_group.nranks, axis=0) - attn_output = paddle.concat(out, axis=1) + attn_output = attn_output.astype(paddle.float32) attn_output = self.proj(attn_output) return attn_output @@ -389,7 +354,7 @@ class DFNRopeVisionBlock(nn.Layer): nn (_type_): _description_ """ - def __init__(self, config, attn_implementation: str = "sdpa") -> None: + def __init__(self, config, tensor_parallel_degree: int, attn_implementation: str = "sdpa") -> None: """_summary_ Args: @@ -404,19 +369,18 @@ class DFNRopeVisionBlock(nn.Layer): self.attn = VisionFlashAttention2( config.embed_dim, num_heads=config.num_heads, - tensor_parallel_degree=config.tensor_parallel_degree) + tensor_parallel_degree=tensor_parallel_degree) self.mlp = VisionMlp( dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act, - tensor_parallel_degree=config.tensor_parallel_degree) + tensor_parallel_degree=tensor_parallel_degree) self.config = config def forward(self, hidden_states, cu_seqlens, - rotary_pos_emb, - attn_sep=False) -> paddle.Tensor: + rotary_pos_emb) -> paddle.Tensor: """_summary_ Args: @@ -431,7 +395,6 @@ class DFNRopeVisionBlock(nn.Layer): self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, - attn_sep=attn_sep, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -490,26 +453,26 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): config_class = DFNRopeVisionTransformerConfig def __init__(self, config, prefix_name: str = "") -> None: - super().__init__(config) - self.spatial_merge_size = config.spatial_merge_size + super().__init__(config.vision_config) + self.spatial_merge_size = config.vision_config.spatial_merge_size self.prefix_name = prefix_name self.patch_embed = PatchEmbed( - patch_size=config.patch_size, - in_channels=config.in_channels, - embed_dim=config.embed_dim, + patch_size=config.vision_config.patch_size, + in_channels=config.vision_config.in_channels, + embed_dim=config.vision_config.embed_dim, ) - head_dim = config.embed_dim // config.num_heads + head_dim = config.vision_config.embed_dim // config.vision_config.num_heads self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.LayerList( - [DFNRopeVisionBlock(config) for _ in range(config.depth)]) + [DFNRopeVisionBlock(config.vision_config, config.pretrained_config.tensor_parallel_degree) for _ in range(config.vision_config.depth)]) assert ( - config.hidden_size == config.embed_dim + config.vision_config.hidden_size == config.vision_config.embed_dim ), "in DFNRope, vit's config.hidden must be equal to config.embed_dim" # self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim) - self.ln = nn.LayerNorm(config.hidden_size, epsilon=1e-6) + self.ln = nn.LayerNorm(config.vision_config.hidden_size, epsilon=1e-6) def get_dtype(self) -> paddle.dtype: """_summary_ @@ -593,7 +556,6 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): else: cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - attn_sep = getattr(self.config, "attn_sep", False) vit_num_recompute_layers = getattr(self.config, "vit_num_recompute_layers", self.config.depth) @@ -601,13 +563,12 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): for idx, blk in enumerate(self.blocks): if self.config.recompute and self.training and idx < vit_num_recompute_layers: hidden_states = recompute(blk, hidden_states, cu_seqlens, - rotary_pos_emb, attn_sep) + rotary_pos_emb) else: hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, - attn_sep=attn_sep, ) # ret = self.merger(hidden_states) diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index d9503bbe2..92e36220c 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -123,6 +123,7 @@ class Ernie4_5_VLMoE(nn.Layer): fd_config=fd_config, intermediate_size=fd_config.model_config.intermediate_size, prefix=f"{prefix}", + reduce_results=False, ) assert image_moe_layer_start_index <= image_moe_layer_end_index @@ -155,6 +156,7 @@ class Ernie4_5_VLMoE(nn.Layer): fd_config=fd_config, intermediate_size=fd_config.model_config.intermediate_size, prefix=f"{prefix}", + reduce_results=False, ) self.num_shared_experts = fd_config.model_config.moe_num_shared_experts @@ -471,8 +473,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): """ super(Ernie4_5_VLMoeForConditionalGeneration, self).__init__(fd_config) # ----------- vision model ------------ - vision_config = fd_config.model_config.vision_config - self.vision_model = self._init_vision_model(vision_config) + self.vision_model = self._init_vision_model(fd_config.model_config) # ----------- resampler_model ------------ self.resampler_model = self._init_resampler_model_model( fd_config.model_config @@ -490,12 +491,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): ) self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings - def _init_vision_model(self, vision_config) -> nn.Layer: + def _init_vision_model(self, model_config) -> nn.Layer: from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope.modeling import \ DFNRopeVisionTransformerPretrainedModel vision_model = DFNRopeVisionTransformerPretrainedModel( - vision_config, prefix_name="vision_model" + model_config, prefix_name="vision_model" ) vision_model = paddle.amp.decorate( models=vision_model, level="O2", dtype="bfloat16" @@ -508,7 +509,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): VariableResolutionResamplerModel resampler_model = VariableResolutionResamplerModel( - model_config.pixel_hidden_size, + model_config.vision_config.hidden_size, model_config.hidden_size, model_config.spatial_conv_size, model_config.temporal_conv_size, diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py index c87e6db78..f85ac235c 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py @@ -104,7 +104,7 @@ class RMSNorm(nn.Layer): self.variance_epsilon = config.rms_norm_eps self.config = config - if config.sequence_parallel: + if getattr(config, "sequence_parallel", False): mark_as_sequence_parallel_parameter(self.weight) def forward(self, hidden_states): @@ -118,7 +118,6 @@ class RMSNorm(nn.Layer): Tensor: Normalized output tensor of same shape as input Note: - - Uses fused kernel if config.fuse_rms_norm is True for better performance - Otherwise computes RMSNorm manually: 1. Compute variance of features 2. Apply reciprocal square root normalization @@ -146,9 +145,9 @@ class VariableResolutionResamplerModel(nn.Layer): self.config = config self.spatial_conv_size = spatial_conv_size self.temporal_conv_size = temporal_conv_size - self.use_recompute_resampler = config.use_recompute_resampler - self.use_temporal_conv = config.use_temporal_conv - self.tensor_parallel_degree = config.tensor_parallel_degree + self.use_recompute_resampler = False + self.use_temporal_conv = True + self.tensor_parallel_degree = config.pretrained_config.tensor_parallel_degree self.prefix_name = prefix_name # for 空间四合一 @@ -165,7 +164,7 @@ class VariableResolutionResamplerModel(nn.Layer): input_is_parallel=True, has_bias=True, fuse_matmul_bias=True, - ) if config.tensor_parallel_degree > 1 else nn.Linear( + ) if self.tensor_parallel_degree > 1 else nn.Linear( self.spatial_dim, self.spatial_dim)), nn.GELU(), nn.Linear(self.spatial_dim, self.spatial_dim), @@ -184,11 +183,9 @@ class VariableResolutionResamplerModel(nn.Layer): out_config = deepcopy(config) out_config.hidden_size = out_dim - # Note(GuoxiaWang): fuse can reduce gpu peak memory - out_config.fuse_rms_norm = out_config.resampler_fuse_rms_norm self.after_norm = RMSNorm(out_config) - if config.tensor_parallel_degree > 1: + if self.tensor_parallel_degree > 1: for idx in [2, 3]: mark_as_sequence_parallel_parameter( self.spatial_linear[idx].weight) diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index ac67d02f0..8aa11897d 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -102,4 +102,4 @@ class RolloutModelConfig: def initialize(self): """Initialize the final fd config""" - return initialize_fd_config(self, self.tensor_parallel_size, 0) + return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=0) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 5bf4ce9c9..80eb3df4e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1447,20 +1447,15 @@ class GPUModelRunner(ModelRunnerBase): self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank - self.fd_config.model_config.moe_group="dummy" - self.fd_config.parallel_config.column_cut = False vision_config = self.fd_config.model_config.vision_config - vision_config.attn_sep = False - vision_config.dtype = "bfloat16" + vision_config.dtype = self.fd_config.model_config.dtype vision_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size vision_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank - self.fd_config.model_config.pixel_hidden_size = vision_config.hidden_size self.fd_config.model_config.im_patch_id = tokenizer.get_vocab()[ "<|IMAGE_PLACEHOLDER|>" ] self.fd_config.model_config.think_end_id = tokenizer.get_vocab()[""] - self.fd_config.model_config.max_text_id = self.fd_config.model_config.im_patch_id - self.fd_config.model_config.sequence_parallel = False + self.fd_config.model_config.sequence_parallel = self.parallel_config.sequence_parallel self.model_config = self.fd_config.model_config self._init_image_preprocess() @@ -1558,9 +1553,9 @@ class GPUModelRunner(ModelRunnerBase): rope_emb = get_rope_3d( position_ids=position_ids_3d_real, rotary_dim=self.model_config.head_dim, - paritial_rotary_factor=1.0, + partial_rotary_factor=1.0, base=self.model_config.rope_theta, max_position=self.parallel_config.max_model_len, - freq_allocation=self.model_config.freq_allocation, + freq_allocation=getattr(self.model_config, "freq_allocation", 20), ) return rope_emb diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 4aa420a5e..8775c5de2 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -567,6 +567,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: decoding_config = DecodingConfig(vars(args)) speculative_config = SpeculativeConfig(vars(args)) parallel_config = ParallelConfig(vars(args)) + parallel_config.tensor_parallel_rank = local_rank + parallel_config.tensor_parallel_size = ranks + parallel_config.expert_parallel_rank = int(local_rank / ranks) load_config = LoadConfig(vars(args)) graph_opt_config = GraphOptimizationConfig(