[SOT] Remove breakgraph in post processing && fix datatype (#2780)

Ryan authored 2025-07-10 11:26:00 +08:00, committed by GitHub
commit b0f525955c (parent 2ea267f624)
3 changed files with 20 additions and 17 deletions


@@ -18,6 +18,11 @@ import paddle
 from fastdeploy.platforms import current_platform
+
+if current_platform.is_cuda():
+    from fastdeploy.model_executor.ops.gpu import \
+        get_block_shape_and_split_kv_block as \
+        get_block_shape_and_split_kv_block_cuda
 def get_block_shape_and_split_kv_block(
     seq_lens_encoder: paddle.Tensor,
@@ -34,7 +39,6 @@ def get_block_shape_and_split_kv_block(
         get_block_shape_and_split_kv_block
     """
     if current_platform.is_cuda():
-        from fastdeploy.model_executor.ops.gpu import get_block_shape_and_split_kv_block
         (
             encoder_batch_ids,
             encoder_tile_ids_per_batch,
@@ -47,7 +51,7 @@ def get_block_shape_and_split_kv_block(
             decoder_num_blocks,
             max_len_kv,
             set_max_lengths,
-        ) = get_block_shape_and_split_kv_block(
+        ) = get_block_shape_and_split_kv_block_cuda(
             seq_lens_encoder,
             seq_lens_decoder,
             seq_lens_this_time,
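The hunks above hoist the GPU custom-op import out of the traced function body and into module scope under a _cuda alias: executing an import statement inside SOT-captured code forces a graph break, while a module-level import runs once and leaves the wrapper as pure dispatch. A minimal sketch of the same pattern, using hypothetical module and function names (my_ops, fast_op) rather than FastDeploy's real layout:

import importlib.util

# Hoisted-import pattern: the platform-specific op is imported once at module
# load; my_ops / fast_op are hypothetical names used only for illustration.
_HAS_CUDA_OPS = importlib.util.find_spec("my_ops") is not None
if _HAS_CUDA_OPS:
    from my_ops.gpu import fast_op as fast_op_cuda  # imported once, not per call

def fast_op(x):
    # Pure dispatch in the traced path: no import statement to trip over,
    # so dynamic-to-static capture does not have to fall back to eager here.
    if _HAS_CUDA_OPS:
        return fast_op_cuda(x)
    raise NotImplementedError("fast_op requires the CUDA ops build")

Aliasing the kernel keeps the public wrapper name unchanged for callers while the CUDA import stays behind the platform check.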


@@ -395,8 +395,8 @@ class Ernie4_5_VLModel(nn.Layer):
         image_mask = ids_remove_padding == self.im_patch_id
         token_type_ids = image_mask.cast("int32")
         token_num = hidden_states.shape[0]
-        image_token_num = paddle.count_nonzero(token_type_ids).cast("int32")
-        text_token_num = paddle.maximum(token_num - image_token_num, paddle.ones([], dtype="int32"))
+        image_token_num = paddle.count_nonzero(token_type_ids)
+        text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
         if image_mask.any():
             hidden_states[image_mask] = image_features.cast(self._dtype)
         text_input = paddle.full(
@@ -444,7 +444,7 @@ class Ernie4_5_VLModel(nn.Layer):
         hidden_states = extract_text_token_output(
             max_seq_len,
             max_seq_len_index.cast("int32"),
-            image_token_num,
+            image_token_num.cast("int32"),
             forward_meta.seq_lens_this_time,
             forward_meta.cu_seqlens_q,
             score_text,
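This is the datatype fix from the title: paddle.count_nonzero returns an int64 tensor and token_num is a plain Python int from shape[0], so token_num - image_token_num is int64. Casting the count to int32 up front, as the old code did, left paddle.maximum comparing mismatched dtypes; the count now stays int64 and is only cast to int32 at the extract_text_token_output call that expects it. A small self-contained sketch of the dtype behavior (toy values, not the model's real tensors):

import paddle

token_type_ids = paddle.to_tensor([0, 1, 1, 0], dtype="int32")
token_num = token_type_ids.shape[0]                      # plain Python int (4)
image_token_num = paddle.count_nonzero(token_type_ids)   # 0-D int64 tensor (2)
text_token_num = paddle.maximum(
    token_num - image_token_num,                          # int - int64 tensor -> int64
    paddle.ones([], dtype="int64"),                       # floor of 1 in the same dtype
)
print(image_token_num.item(), text_token_num.item())      # 2 2
# Ops that still want int32 get an explicit cast at the call site instead,
# as the second hunk does with image_token_num.cast("int32").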


@@ -929,18 +929,17 @@ class GPUVLModelRunner(VLModelRunnerBase):
             False,
         ) # multi ends
         # update inputs
-        with paddle.framework._no_check_dy2st_diff():
-            update_inputs(
-                self.share_inputs["stop_flags"],
-                self.share_inputs["not_need_stop"],
-                self.share_inputs["seq_lens_this_time"],
-                self.share_inputs["seq_lens_encoder"],
-                self.share_inputs["seq_lens_decoder"],
-                self.share_inputs["input_ids"],
-                self.share_inputs["stop_nums"],
-                next_tokens,
-                self.share_inputs["is_block_step"],
-            )
+        update_inputs(
+            self.share_inputs["stop_flags"],
+            self.share_inputs["not_need_stop"],
+            self.share_inputs["seq_lens_this_time"],
+            self.share_inputs["seq_lens_encoder"],
+            self.share_inputs["seq_lens_decoder"],
+            self.share_inputs["input_ids"],
+            self.share_inputs["stop_nums"],
+            next_tokens,
+            self.share_inputs["is_block_step"],
+        )
         save_output(
             next_tokens,
             self.share_inputs["not_need_stop"],
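This hunk is the "remove breakgraph in post processing" part of the title: the private paddle.framework._no_check_dy2st_diff() wrapper (which, going by its name, suppresses dynamic-to-static difference checks) is dropped and update_inputs is called directly, since entering a Python-level context manager inside captured code is the kind of construct that presumably forced a graph break here. The post-processing works by mutating the shared input buffers in place; a rough sketch of that style of update, with hypothetical buffers and nothing like the real update_inputs signature:

import paddle

stop_flags = paddle.to_tensor([False, True])
seq_lens_this_time = paddle.to_tensor([1, 1], dtype="int32")

# Finished sequences contribute no tokens to the next decode step; the buffer
# is updated in place, so no result needs to flow back through the graph.
seq_lens_this_time[:] = paddle.where(
    stop_flags,
    paddle.zeros_like(seq_lens_this_time),
    seq_lens_this_time,
)
not_need_stop = paddle.logical_not(paddle.all(stop_flags))
print(seq_lens_this_time.numpy(), bool(not_need_stop))  # [1 0] True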