[SOT] Remove breakgraph in post-processing and fix datatype (#2780)

This commit is contained in:
Ryan
2025-07-10 11:26:00 +08:00
committed by GitHub
parent 2ea267f624
commit b0f525955c
3 changed files with 20 additions and 17 deletions

View File

@@ -18,6 +18,11 @@ import paddle
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import \
get_block_shape_and_split_kv_block as \
get_block_shape_and_split_kv_block_cuda
def get_block_shape_and_split_kv_block(
seq_lens_encoder: paddle.Tensor,
@@ -34,7 +39,6 @@ def get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block
"""
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import get_block_shape_and_split_kv_block
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -47,7 +51,7 @@ def get_block_shape_and_split_kv_block(
decoder_num_blocks,
max_len_kv,
set_max_lengths,
) = get_block_shape_and_split_kv_block(
) = get_block_shape_and_split_kv_block_cuda(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,

View File

@@ -395,8 +395,8 @@ class Ernie4_5_VLModel(nn.Layer):
image_mask = ids_remove_padding == self.im_patch_id
token_type_ids = image_mask.cast("int32")
token_num = hidden_states.shape[0]
image_token_num = paddle.count_nonzero(token_type_ids).cast("int32")
text_token_num = paddle.maximum(token_num - image_token_num, paddle.ones([], dtype="int32"))
image_token_num = paddle.count_nonzero(token_type_ids)
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
if image_mask.any():
hidden_states[image_mask] = image_features.cast(self._dtype)
text_input = paddle.full(
@@ -444,7 +444,7 @@ class Ernie4_5_VLModel(nn.Layer):
hidden_states = extract_text_token_output(
max_seq_len,
max_seq_len_index.cast("int32"),
image_token_num,
image_token_num.cast("int32"),
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
score_text,

View File

@@ -929,7 +929,6 @@ class GPUVLModelRunner(VLModelRunnerBase):
False,
) # multi ends
# update inputs
with paddle.framework._no_check_dy2st_diff():
update_inputs(
self.share_inputs["stop_flags"],
self.share_inputs["not_need_stop"],