[SOT] Remove BreakGraph with paddle.maximum (#2731)

* rm if with clip

* clip -> maximum

* int64 -> int32
This commit is contained in:
Ryan
2025-07-08 11:44:25 +08:00
committed by GitHub
parent 1eb8ea7328
commit fefbd65cf8

View File

@@ -393,8 +393,7 @@ class Ernie4_5_VLModel(nn.Layer):
token_type_ids = image_mask.cast("int32")
token_num = hidden_states.shape[0]
image_token_num = paddle.count_nonzero(token_type_ids).cast("int32")
text_token_num = ((token_num - image_token_num) if
(token_num - image_token_num) > 0 else 1)
text_token_num = paddle.maximum(token_num - image_token_num, paddle.ones([], dtype="int32"))
if image_mask.any():
hidden_states[image_mask] = image_features.cast(self._dtype)
text_input = paddle.full(