mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-14 04:44:00 +08:00
[SOT] Remove breakgraph in post processing && fix datatype (#2780)
This commit is contained in:
@@ -18,6 +18,11 @@ import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
get_block_shape_and_split_kv_block as \
|
||||
get_block_shape_and_split_kv_block_cuda
|
||||
|
||||
|
||||
def get_block_shape_and_split_kv_block(
|
||||
seq_lens_encoder: paddle.Tensor,
|
||||
@@ -34,7 +39,6 @@ def get_block_shape_and_split_kv_block(
|
||||
get_block_shape_and_split_kv_block
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import get_block_shape_and_split_kv_block
|
||||
(
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
@@ -47,7 +51,7 @@ def get_block_shape_and_split_kv_block(
|
||||
decoder_num_blocks,
|
||||
max_len_kv,
|
||||
set_max_lengths,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
) = get_block_shape_and_split_kv_block_cuda(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
|
@@ -395,8 +395,8 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
image_mask = ids_remove_padding == self.im_patch_id
|
||||
token_type_ids = image_mask.cast("int32")
|
||||
token_num = hidden_states.shape[0]
|
||||
image_token_num = paddle.count_nonzero(token_type_ids).cast("int32")
|
||||
text_token_num = paddle.maximum(token_num - image_token_num, paddle.ones([], dtype="int32"))
|
||||
image_token_num = paddle.count_nonzero(token_type_ids)
|
||||
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
||||
if image_mask.any():
|
||||
hidden_states[image_mask] = image_features.cast(self._dtype)
|
||||
text_input = paddle.full(
|
||||
@@ -444,7 +444,7 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
hidden_states = extract_text_token_output(
|
||||
max_seq_len,
|
||||
max_seq_len_index.cast("int32"),
|
||||
image_token_num,
|
||||
image_token_num.cast("int32"),
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
score_text,
|
||||
|
@@ -929,7 +929,6 @@ class GPUVLModelRunner(VLModelRunnerBase):
|
||||
False,
|
||||
) # multi ends
|
||||
# update inputs
|
||||
with paddle.framework._no_check_dy2st_diff():
|
||||
update_inputs(
|
||||
self.share_inputs["stop_flags"],
|
||||
self.share_inputs["not_need_stop"],
|
||||
|
Reference in New Issue
Block a user