mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Code Simplification] Refactor Post-processing in VL Model Forward Method (#2937)
* rm sth useless * refactor model forward * mv bool index to kernel
This commit is contained in:
@@ -418,17 +418,16 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
text_index = None
|
||||
image_index = None
|
||||
fake_hidden_states = None
|
||||
image_token_num = 0
|
||||
|
||||
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||
token_num, hidden_dim = hidden_states.shape
|
||||
|
||||
# -----------------------
|
||||
image_mask = ids_remove_padding == self.im_patch_id
|
||||
token_type_ids = image_mask.cast("int32")
|
||||
token_num = hidden_states.shape[0]
|
||||
image_token_num = paddle.count_nonzero(token_type_ids)
|
||||
image_token_num = image_mask.sum()
|
||||
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
||||
|
||||
token_type_ids = image_mask.cast("int32")
|
||||
if self.fd_config.parallel_config.use_ep is True:
|
||||
fake_hidden_states = paddle.empty(
|
||||
shape=[0, self.fd_config.model_config.hidden_size],
|
||||
@@ -436,20 +435,18 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
)
|
||||
text_input = fake_hidden_states
|
||||
|
||||
if image_mask.any():
|
||||
if image_token_num > 0:
|
||||
hidden_states[image_mask] = image_features.cast(self._dtype)
|
||||
text_input = paddle.full(
|
||||
shape=[text_token_num, hidden_states.shape[1]],
|
||||
fill_value=1,
|
||||
text_input = paddle.ones(
|
||||
shape=[text_token_num, hidden_dim],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
image_input = paddle.full(
|
||||
shape=[image_token_num, hidden_states.shape[1]],
|
||||
fill_value=1,
|
||||
image_input = paddle.ones(
|
||||
shape=[image_token_num, hidden_dim],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
text_index = paddle.zeros_like(token_type_ids)
|
||||
image_index = paddle.zeros_like(token_type_ids)
|
||||
text_index = paddle.zeros_like(image_mask, dtype="int32")
|
||||
image_index = paddle.zeros_like(image_mask, dtype="int32")
|
||||
text_image_index_out(token_type_ids, text_index, image_index)
|
||||
|
||||
vl_moe_meta = VLMoEMeta(
|
||||
@@ -474,21 +471,14 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# -----------------------
|
||||
hidden_states = hidden_states.cast("float32")
|
||||
score_text = hidden_states
|
||||
|
||||
if image_input is not None:
|
||||
token_type_ids = token_type_ids.reshape([-1])
|
||||
text_pos_shifted = token_type_ids[:token_num] == 0
|
||||
score_text = hidden_states[text_pos_shifted.reshape([-1])]
|
||||
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time.squeeze(-1), k=1)
|
||||
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
|
||||
hidden_states = extract_text_token_output(
|
||||
max_seq_len,
|
||||
max_seq_len_index.cast("int32"),
|
||||
image_token_num.cast("int32"),
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
score_text,
|
||||
hidden_states.cast("float32"),
|
||||
).cast(self._dtype)
|
||||
# -----------------------
|
||||
|
||||
|
Reference in New Issue
Block a user