mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Code Simplification] Refactor Post-processing in VL Model Forward Method (#2937)
* rm sth useless * refactor model forward * mv bool index to kernel
This commit is contained in:
@@ -323,7 +323,7 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
|
|||||||
const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index,
|
const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index,
|
||||||
const paddle::Tensor &mm_token_num_len,
|
const paddle::Tensor &mm_token_num_len,
|
||||||
const paddle::Tensor &seq_lens_this_time,
|
const paddle::Tensor &seq_lens_this_time,
|
||||||
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text);
|
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &hidden_states);
|
||||||
|
|
||||||
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
|
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
|
||||||
const paddle::Tensor &topk_idx,
|
const paddle::Tensor &topk_idx,
|
||||||
|
@@ -20,7 +20,7 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len,
|
|||||||
int *mm_token_num_len,
|
int *mm_token_num_len,
|
||||||
int *seq_lens_this_time,
|
int *seq_lens_this_time,
|
||||||
int *cu_seqlens_q,
|
int *cu_seqlens_q,
|
||||||
float *score_text,
|
float *hidden_states,
|
||||||
float *output,
|
float *output,
|
||||||
const int bsz,
|
const int bsz,
|
||||||
const int hidden_size) {
|
const int hidden_size) {
|
||||||
@@ -32,14 +32,11 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len,
|
|||||||
int max_seq_len_index_data = max_seq_len_index[0];
|
int max_seq_len_index_data = max_seq_len_index[0];
|
||||||
int mm_token_num_len_data = mm_token_num_len[0];
|
int mm_token_num_len_data = mm_token_num_len[0];
|
||||||
int true_bsz = cu_seqlens_q[bsz_index + 1] - 1;
|
int true_bsz = cu_seqlens_q[bsz_index + 1] - 1;
|
||||||
if (bsz_index >= max_seq_len_index_data) {
|
|
||||||
true_bsz = true_bsz - mm_token_num_len_data;
|
|
||||||
}
|
|
||||||
if (max_seq_len_data == mm_token_num_len_data && bsz_index == max_seq_len_index_data) {
|
if (max_seq_len_data == mm_token_num_len_data && bsz_index == max_seq_len_index_data) {
|
||||||
output[bsz_index * hidden_size + block_idx] = 0.0;
|
output[bsz_index * hidden_size + block_idx] = 0.0;
|
||||||
} else {
|
} else {
|
||||||
if (seq_lens_this_time[bsz_index] != 0) {
|
if (seq_lens_this_time[bsz_index] != 0) {
|
||||||
output[bsz_index * hidden_size + block_idx] = score_text[true_bsz * hidden_size + block_idx];
|
output[bsz_index * hidden_size + block_idx] = hidden_states[true_bsz * hidden_size + block_idx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -51,19 +48,19 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
|
|||||||
const paddle::Tensor& mm_token_num_len,
|
const paddle::Tensor& mm_token_num_len,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& score_text) {
|
const paddle::Tensor& hidden_states) {
|
||||||
|
|
||||||
const int bsz = seq_lens_this_time.shape()[0];
|
const int bsz = seq_lens_this_time.shape()[0];
|
||||||
const int hidden_size = score_text.shape()[1];
|
const int hidden_size = hidden_states.shape()[1];
|
||||||
paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, score_text.place());
|
paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, hidden_states.place());
|
||||||
|
|
||||||
extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, score_text.stream()>>>(
|
extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, hidden_states.stream()>>>(
|
||||||
const_cast<int*>(max_seq_len.data<int>()),
|
const_cast<int*>(max_seq_len.data<int>()),
|
||||||
const_cast<int*>(max_seq_len_index.data<int>()),
|
const_cast<int*>(max_seq_len_index.data<int>()),
|
||||||
const_cast<int*>(mm_token_num_len.data<int>()),
|
const_cast<int*>(mm_token_num_len.data<int>()),
|
||||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||||
const_cast<int*>(cu_seqlens_q.data<int>()),
|
const_cast<int*>(cu_seqlens_q.data<int>()),
|
||||||
const_cast<float*>(score_text.data<float>()),
|
const_cast<float*>(hidden_states.data<float>()),
|
||||||
output.data<float>(),
|
output.data<float>(),
|
||||||
bsz,
|
bsz,
|
||||||
hidden_size
|
hidden_size
|
||||||
@@ -76,9 +73,9 @@ std::vector<std::vector<int64_t>> ExtractTextTokenOutputInferShape(const std::ve
|
|||||||
const std::vector<int64_t>& mm_token_num_len_shape,
|
const std::vector<int64_t>& mm_token_num_len_shape,
|
||||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||||
const std::vector<int64_t>& score_text_shape) {
|
const std::vector<int64_t>& hidden_states_shape) {
|
||||||
const int bsz = seq_lens_this_time_shape[0];
|
const int bsz = seq_lens_this_time_shape[0];
|
||||||
const int hidden_size = score_text_shape[1];
|
const int hidden_size = hidden_states_shape[1];
|
||||||
return {{bsz, hidden_size}};
|
return {{bsz, hidden_size}};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,8 +84,8 @@ std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(const paddle::Dat
|
|||||||
const paddle::DataType& mm_token_num_len_dtype,
|
const paddle::DataType& mm_token_num_len_dtype,
|
||||||
const paddle::DataType& seq_lens_this_time_dtype,
|
const paddle::DataType& seq_lens_this_time_dtype,
|
||||||
const paddle::DataType& cu_seqlens_q_dtype,
|
const paddle::DataType& cu_seqlens_q_dtype,
|
||||||
const paddle::DataType& score_text_dtype) {
|
const paddle::DataType& hidden_states_dtype) {
|
||||||
return {score_text_dtype};
|
return {hidden_states_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(extract_text_token_output)
|
PD_BUILD_STATIC_OP(extract_text_token_output)
|
||||||
@@ -97,7 +94,7 @@ PD_BUILD_STATIC_OP(extract_text_token_output)
|
|||||||
"mm_token_num_len",
|
"mm_token_num_len",
|
||||||
"seq_lens_this_time",
|
"seq_lens_this_time",
|
||||||
"cu_seqlens_q",
|
"cu_seqlens_q",
|
||||||
"score_text"})
|
"hidden_states"})
|
||||||
.Outputs({"output"})
|
.Outputs({"output"})
|
||||||
.SetKernelFn(PD_KERNEL(ExtractTextTokenOutput))
|
.SetKernelFn(PD_KERNEL(ExtractTextTokenOutput))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape))
|
||||||
|
@@ -418,17 +418,16 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
text_index = None
|
text_index = None
|
||||||
image_index = None
|
image_index = None
|
||||||
fake_hidden_states = None
|
fake_hidden_states = None
|
||||||
image_token_num = 0
|
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||||
|
token_num, hidden_dim = hidden_states.shape
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
image_mask = ids_remove_padding == self.im_patch_id
|
image_mask = ids_remove_padding == self.im_patch_id
|
||||||
token_type_ids = image_mask.cast("int32")
|
image_token_num = image_mask.sum()
|
||||||
token_num = hidden_states.shape[0]
|
|
||||||
image_token_num = paddle.count_nonzero(token_type_ids)
|
|
||||||
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
||||||
|
|
||||||
|
token_type_ids = image_mask.cast("int32")
|
||||||
if self.fd_config.parallel_config.use_ep is True:
|
if self.fd_config.parallel_config.use_ep is True:
|
||||||
fake_hidden_states = paddle.empty(
|
fake_hidden_states = paddle.empty(
|
||||||
shape=[0, self.fd_config.model_config.hidden_size],
|
shape=[0, self.fd_config.model_config.hidden_size],
|
||||||
@@ -436,20 +435,18 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
)
|
)
|
||||||
text_input = fake_hidden_states
|
text_input = fake_hidden_states
|
||||||
|
|
||||||
if image_mask.any():
|
if image_token_num > 0:
|
||||||
hidden_states[image_mask] = image_features.cast(self._dtype)
|
hidden_states[image_mask] = image_features.cast(self._dtype)
|
||||||
text_input = paddle.full(
|
text_input = paddle.ones(
|
||||||
shape=[text_token_num, hidden_states.shape[1]],
|
shape=[text_token_num, hidden_dim],
|
||||||
fill_value=1,
|
|
||||||
dtype=self._dtype,
|
dtype=self._dtype,
|
||||||
)
|
)
|
||||||
image_input = paddle.full(
|
image_input = paddle.ones(
|
||||||
shape=[image_token_num, hidden_states.shape[1]],
|
shape=[image_token_num, hidden_dim],
|
||||||
fill_value=1,
|
|
||||||
dtype=self._dtype,
|
dtype=self._dtype,
|
||||||
)
|
)
|
||||||
text_index = paddle.zeros_like(token_type_ids)
|
text_index = paddle.zeros_like(image_mask, dtype="int32")
|
||||||
image_index = paddle.zeros_like(token_type_ids)
|
image_index = paddle.zeros_like(image_mask, dtype="int32")
|
||||||
text_image_index_out(token_type_ids, text_index, image_index)
|
text_image_index_out(token_type_ids, text_index, image_index)
|
||||||
|
|
||||||
vl_moe_meta = VLMoEMeta(
|
vl_moe_meta = VLMoEMeta(
|
||||||
@@ -474,21 +471,14 @@ class Ernie4_5_VLModel(nn.Layer):
|
|||||||
hidden_states = hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
hidden_states = hidden_states.cast("float32")
|
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
|
||||||
score_text = hidden_states
|
|
||||||
|
|
||||||
if image_input is not None:
|
|
||||||
token_type_ids = token_type_ids.reshape([-1])
|
|
||||||
text_pos_shifted = token_type_ids[:token_num] == 0
|
|
||||||
score_text = hidden_states[text_pos_shifted.reshape([-1])]
|
|
||||||
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time.squeeze(-1), k=1)
|
|
||||||
hidden_states = extract_text_token_output(
|
hidden_states = extract_text_token_output(
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_seq_len_index.cast("int32"),
|
max_seq_len_index.cast("int32"),
|
||||||
image_token_num.cast("int32"),
|
image_token_num.cast("int32"),
|
||||||
forward_meta.seq_lens_this_time,
|
forward_meta.seq_lens_this_time,
|
||||||
forward_meta.cu_seqlens_q,
|
forward_meta.cu_seqlens_q,
|
||||||
score_text,
|
hidden_states.cast("float32"),
|
||||||
).cast(self._dtype)
|
).cast(self._dtype)
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user