diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index f5a0f0421..2dd562135 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -69,6 +69,7 @@ class VLMoEMeta: text_index: Optional[paddle.Tensor] = None image_index: Optional[paddle.Tensor] = None token_type_ids: Optional[paddle.Tensor] = None + fake_hidden_states: Optional[paddle.Tensor] = None class Ernie4_5_VLMoE(nn.Layer): @@ -241,6 +242,8 @@ class Ernie4_5_VLMoE(nn.Layer): ) else: hidden_states = self.text_fused_moe(hidden_states) + if vl_moe_meta.fake_hidden_states is not None: + self.image_fused_moe(vl_moe_meta.fake_hidden_states) if self.num_shared_experts > 0: hidden_states += shared_experts_out if self.tp_size > 1: @@ -362,6 +365,7 @@ class Ernie4_5_VLModel(nn.Layer): self.im_patch_id = fd_config.model_config.im_patch_id self._dtype = fd_config.model_config.dtype fd_config.model_config.pretrained_config.prefix_name = "ernie" + self.fd_config = fd_config self.embed_tokens = VocabParallelEmbedding( fd_config=fd_config, @@ -413,6 +417,7 @@ class Ernie4_5_VLModel(nn.Layer): image_input = None text_index = None image_index = None + fake_hidden_states = None image_token_num = 0 hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) @@ -423,6 +428,14 @@ class Ernie4_5_VLModel(nn.Layer): token_num = hidden_states.shape[0] image_token_num = paddle.count_nonzero(token_type_ids) text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64")) + + if self.fd_config.parallel_config.use_ep is True: + fake_hidden_states = paddle.empty( + shape=[0, self.fd_config.model_config.hidden_size], + dtype=paddle.get_default_dtype(), + ) + text_input = fake_hidden_states + if image_mask.any(): hidden_states[image_mask] = image_features.cast(self._dtype) text_input = paddle.full( @@ -445,6 +458,7 @@ class Ernie4_5_VLModel(nn.Layer): text_index=text_index, image_index=image_index, token_type_ids=token_type_ids, + fake_hidden_states=fake_hidden_states, ) # ----------------------- @@ -580,6 +594,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): self.fd_config.model_config.num_hidden_layers, ): self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states) + self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states) def forward( self,