From 94264bbf60fd6357ec44a19c3769fd86517424d8 Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 1 Aug 2025 17:28:07 +0800 Subject: [PATCH] [Code Simplification] Refactor Post-processing in VL Model Forward Method (#2937) * rm sth useless * refactor model forward * mv bool index to kernel --- custom_ops/gpu_ops/cpp_extensions.cc | 2 +- .../gpu_ops/extract_text_token_output.cu | 27 +++++++-------- .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 34 +++++++------------ 3 files changed, 25 insertions(+), 38 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index b4d7b952d..4639d1e93 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -323,7 +323,7 @@ std::vector ExtractTextTokenOutput( const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index, const paddle::Tensor &mm_token_num_len, const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text); + const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &hidden_states); std::vector MoEDeepGEMMPermute(const paddle::Tensor &x, const paddle::Tensor &topk_idx, diff --git a/custom_ops/gpu_ops/extract_text_token_output.cu b/custom_ops/gpu_ops/extract_text_token_output.cu index ff04a813e..4459b967e 100644 --- a/custom_ops/gpu_ops/extract_text_token_output.cu +++ b/custom_ops/gpu_ops/extract_text_token_output.cu @@ -20,7 +20,7 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len, int *mm_token_num_len, int *seq_lens_this_time, int *cu_seqlens_q, - float *score_text, + float *hidden_states, float *output, const int bsz, const int hidden_size) { @@ -32,14 +32,11 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len, int max_seq_len_index_data = max_seq_len_index[0]; int mm_token_num_len_data = mm_token_num_len[0]; int true_bsz = cu_seqlens_q[bsz_index + 1] - 1; - if (bsz_index >= max_seq_len_index_data) { - true_bsz = true_bsz - mm_token_num_len_data; - } if (max_seq_len_data == mm_token_num_len_data && bsz_index == max_seq_len_index_data) { output[bsz_index * hidden_size + block_idx] = 0.0; } else { if (seq_lens_this_time[bsz_index] != 0) { - output[bsz_index * hidden_size + block_idx] = score_text[true_bsz * hidden_size + block_idx]; + output[bsz_index * hidden_size + block_idx] = hidden_states[true_bsz * hidden_size + block_idx]; } } __syncthreads(); @@ -51,19 +48,19 @@ std::vector ExtractTextTokenOutput( const paddle::Tensor& mm_token_num_len, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& score_text) { + const paddle::Tensor& hidden_states) { const int bsz = seq_lens_this_time.shape()[0]; - const int hidden_size = score_text.shape()[1]; - paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, score_text.place()); + const int hidden_size = hidden_states.shape()[1]; + paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, hidden_states.place()); - extract_text_token_output_kernel<1024><<>>( + extract_text_token_output_kernel<1024><<>>( const_cast(max_seq_len.data()), const_cast(max_seq_len_index.data()), const_cast(mm_token_num_len.data()), const_cast(seq_lens_this_time.data()), const_cast(cu_seqlens_q.data()), - const_cast(score_text.data()), + const_cast(hidden_states.data()), output.data(), bsz, hidden_size @@ -76,9 +73,9 @@ std::vector> ExtractTextTokenOutputInferShape(const std::ve const std::vector& mm_token_num_len_shape, const std::vector& seq_lens_this_time_shape, const std::vector& cu_seqlens_q_shape, - const std::vector& score_text_shape) { + const std::vector& hidden_states_shape) { const int bsz = seq_lens_this_time_shape[0]; - const int hidden_size = score_text_shape[1]; + const int hidden_size = hidden_states_shape[1]; return {{bsz, hidden_size}}; } @@ -87,8 +84,8 @@ std::vector ExtractTextTokenOutputInferDtype(const paddle::Dat const paddle::DataType& mm_token_num_len_dtype, const paddle::DataType& seq_lens_this_time_dtype, const paddle::DataType& cu_seqlens_q_dtype, - const paddle::DataType& score_text_dtype) { - return {score_text_dtype}; + const paddle::DataType& hidden_states_dtype) { + return {hidden_states_dtype}; } PD_BUILD_STATIC_OP(extract_text_token_output) @@ -97,7 +94,7 @@ PD_BUILD_STATIC_OP(extract_text_token_output) "mm_token_num_len", "seq_lens_this_time", "cu_seqlens_q", - "score_text"}) + "hidden_states"}) .Outputs({"output"}) .SetKernelFn(PD_KERNEL(ExtractTextTokenOutput)) .SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape)) diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 2dd562135..6016b06fd 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -418,17 +418,16 @@ class Ernie4_5_VLModel(nn.Layer): text_index = None image_index = None fake_hidden_states = None - image_token_num = 0 hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) + token_num, hidden_dim = hidden_states.shape # ----------------------- image_mask = ids_remove_padding == self.im_patch_id - token_type_ids = image_mask.cast("int32") - token_num = hidden_states.shape[0] - image_token_num = paddle.count_nonzero(token_type_ids) + image_token_num = image_mask.sum() text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64")) + token_type_ids = image_mask.cast("int32") if self.fd_config.parallel_config.use_ep is True: fake_hidden_states = paddle.empty( shape=[0, self.fd_config.model_config.hidden_size], @@ -436,20 +435,18 @@ class Ernie4_5_VLModel(nn.Layer): ) text_input = fake_hidden_states - if image_mask.any(): + if image_token_num > 0: hidden_states[image_mask] = image_features.cast(self._dtype) - text_input = paddle.full( - shape=[text_token_num, hidden_states.shape[1]], - fill_value=1, + text_input = paddle.ones( + shape=[text_token_num, hidden_dim], dtype=self._dtype, ) - image_input = paddle.full( - shape=[image_token_num, hidden_states.shape[1]], - fill_value=1, + image_input = paddle.ones( + shape=[image_token_num, hidden_dim], dtype=self._dtype, ) - text_index = paddle.zeros_like(token_type_ids) - image_index = paddle.zeros_like(token_type_ids) + text_index = paddle.zeros_like(image_mask, dtype="int32") + image_index = paddle.zeros_like(image_mask, dtype="int32") text_image_index_out(token_type_ids, text_index, image_index) vl_moe_meta = VLMoEMeta( @@ -474,21 +471,14 @@ class Ernie4_5_VLModel(nn.Layer): hidden_states = hidden_states + residual # ----------------------- - hidden_states = hidden_states.cast("float32") - score_text = hidden_states - - if image_input is not None: - token_type_ids = token_type_ids.reshape([-1]) - text_pos_shifted = token_type_ids[:token_num] == 0 - score_text = hidden_states[text_pos_shifted.reshape([-1])] - max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time.squeeze(-1), k=1) + max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1) hidden_states = extract_text_token_output( max_seq_len, max_seq_len_index.cast("int32"), image_token_num.cast("int32"), forward_meta.seq_lens_this_time, forward_meta.cu_seqlens_q, - score_text, + hidden_states.cast("float32"), ).cast(self._dtype) # -----------------------