mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Fix] fix mm ep empty run (#2999)
This commit is contained in:
@@ -69,6 +69,7 @@ class VLMoEMeta:
|
|||||||
text_index: Optional[paddle.Tensor] = None
|
text_index: Optional[paddle.Tensor] = None
|
||||||
image_index: Optional[paddle.Tensor] = None
|
image_index: Optional[paddle.Tensor] = None
|
||||||
token_type_ids: Optional[paddle.Tensor] = None
|
token_type_ids: Optional[paddle.Tensor] = None
|
||||||
|
fake_hidden_states: Optional[paddle.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class Ernie4_5_VLMoE(nn.Layer):
|
class Ernie4_5_VLMoE(nn.Layer):
|
||||||
@@ -241,6 +242,8 @@ class Ernie4_5_VLMoE(nn.Layer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.text_fused_moe(hidden_states)
|
hidden_states = self.text_fused_moe(hidden_states)
|
||||||
|
if vl_moe_meta.fake_hidden_states is not None:
|
||||||
|
self.image_fused_moe(vl_moe_meta.fake_hidden_states)
|
||||||
if self.num_shared_experts > 0:
|
if self.num_shared_experts > 0:
|
||||||
hidden_states += shared_experts_out
|
hidden_states += shared_experts_out
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@@ -362,6 +365,7 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
self.im_patch_id = fd_config.model_config.im_patch_id
|
self.im_patch_id = fd_config.model_config.im_patch_id
|
||||||
self._dtype = fd_config.model_config.dtype
|
self._dtype = fd_config.model_config.dtype
|
||||||
fd_config.model_config.pretrained_config.prefix_name = "ernie"
|
fd_config.model_config.pretrained_config.prefix_name = "ernie"
|
||||||
|
self.fd_config = fd_config
|
||||||
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
fd_config=fd_config,
|
fd_config=fd_config,
|
||||||
@@ -413,6 +417,7 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
image_input = None
|
image_input = None
|
||||||
text_index = None
|
text_index = None
|
||||||
image_index = None
|
image_index = None
|
||||||
|
fake_hidden_states = None
|
||||||
image_token_num = 0
|
image_token_num = 0
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||||
@@ -423,6 +428,14 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
token_num = hidden_states.shape[0]
|
token_num = hidden_states.shape[0]
|
||||||
image_token_num = paddle.count_nonzero(token_type_ids)
|
image_token_num = paddle.count_nonzero(token_type_ids)
|
||||||
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
||||||
|
|
||||||
|
if self.fd_config.parallel_config.use_ep is True:
|
||||||
|
fake_hidden_states = paddle.empty(
|
||||||
|
shape=[0, self.fd_config.model_config.hidden_size],
|
||||||
|
dtype=paddle.get_default_dtype(),
|
||||||
|
)
|
||||||
|
text_input = fake_hidden_states
|
||||||
|
|
||||||
if image_mask.any():
|
if image_mask.any():
|
||||||
hidden_states[image_mask] = image_features.cast(self._dtype)
|
hidden_states[image_mask] = image_features.cast(self._dtype)
|
||||||
text_input = paddle.full(
|
text_input = paddle.full(
|
||||||
@@ -445,6 +458,7 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
text_index=text_index,
|
text_index=text_index,
|
||||||
image_index=image_index,
|
image_index=image_index,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
|
fake_hidden_states=fake_hidden_states,
|
||||||
)
|
)
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
@@ -580,6 +594,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
|||||||
self.fd_config.model_config.num_hidden_layers,
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
):
|
):
|
||||||
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
||||||
|
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
Reference in New Issue
Block a user