[Fix] fix mm ep empty run (#2999)

This commit is contained in:
xiaoxiaohehe001
2025-07-24 14:15:55 +08:00
committed by GitHub
parent e3a843f2c5
commit 2c0ff068e2

View File

@@ -69,6 +69,7 @@ class VLMoEMeta:
text_index: Optional[paddle.Tensor] = None
image_index: Optional[paddle.Tensor] = None
token_type_ids: Optional[paddle.Tensor] = None
fake_hidden_states: Optional[paddle.Tensor] = None
class Ernie4_5_VLMoE(nn.Layer):
@@ -241,6 +242,8 @@ class Ernie4_5_VLMoE(nn.Layer):
)
else:
hidden_states = self.text_fused_moe(hidden_states)
if vl_moe_meta.fake_hidden_states is not None:
self.image_fused_moe(vl_moe_meta.fake_hidden_states)
if self.num_shared_experts > 0:
hidden_states += shared_experts_out
if self.tp_size > 1:
@@ -362,6 +365,7 @@ class Ernie4_5_VLModel(nn.Layer):
self.im_patch_id = fd_config.model_config.im_patch_id
self._dtype = fd_config.model_config.dtype
fd_config.model_config.pretrained_config.prefix_name = "ernie"
self.fd_config = fd_config

self.embed_tokens = VocabParallelEmbedding(
fd_config=fd_config,
@@ -413,6 +417,7 @@ class Ernie4_5_VLModel(nn.Layer):
image_input = None
text_index = None
image_index = None
fake_hidden_states = None
image_token_num = 0
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
@@ -423,6 +428,14 @@ class Ernie4_5_VLModel(nn.Layer):
token_num = hidden_states.shape[0]
image_token_num = paddle.count_nonzero(token_type_ids)
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
if self.fd_config.parallel_config.use_ep is True:
fake_hidden_states = paddle.empty(
shape=[0, self.fd_config.model_config.hidden_size],
dtype=paddle.get_default_dtype(),
)
text_input = fake_hidden_states
if image_mask.any():
hidden_states[image_mask] = image_features.cast(self._dtype)
text_input = paddle.full(
@@ -445,6 +458,7 @@ class Ernie4_5_VLModel(nn.Layer):
text_index=text_index,
image_index=image_index,
token_type_ids=token_type_ids,
fake_hidden_states=fake_hidden_states,
)
# -----------------------
@@ -580,6 +594,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
self.fd_config.model_config.num_hidden_layers,
):
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
def forward(
self,