[Code Simplification] Refactor Post-processing in VL Model Forward Method (#2937)

* rm sth useless

* refactor model forward

* mv bool index to kernel
This commit is contained in:
Ryan
2025-08-01 17:28:07 +08:00
committed by GitHub
parent 3a4db15765
commit 94264bbf60
3 changed files with 25 additions and 38 deletions

View File

@@ -323,7 +323,7 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index,
const paddle::Tensor &mm_token_num_len,
const paddle::Tensor &seq_lens_this_time,
-const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text);
+const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &hidden_states);
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
const paddle::Tensor &topk_idx,

View File

@@ -20,7 +20,7 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len,
int *mm_token_num_len,
int *seq_lens_this_time,
int *cu_seqlens_q,
-float *score_text,
+float *hidden_states,
float *output,
const int bsz,
const int hidden_size) {
@@ -32,14 +32,11 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len,
int max_seq_len_index_data = max_seq_len_index[0];
int mm_token_num_len_data = mm_token_num_len[0];
int true_bsz = cu_seqlens_q[bsz_index + 1] - 1;
if (bsz_index >= max_seq_len_index_data) {
true_bsz = true_bsz - mm_token_num_len_data;
}
if (max_seq_len_data == mm_token_num_len_data && bsz_index == max_seq_len_index_data) {
output[bsz_index * hidden_size + block_idx] = 0.0;
} else {
if (seq_lens_this_time[bsz_index] != 0) {
-output[bsz_index * hidden_size + block_idx] = score_text[true_bsz * hidden_size + block_idx];
+output[bsz_index * hidden_size + block_idx] = hidden_states[true_bsz * hidden_size + block_idx];
}
}
__syncthreads();
@@ -51,19 +48,19 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
const paddle::Tensor& mm_token_num_len,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
-const paddle::Tensor& score_text) {
+const paddle::Tensor& hidden_states) {
const int bsz = seq_lens_this_time.shape()[0];
-const int hidden_size = score_text.shape()[1];
-paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, score_text.place());
+const int hidden_size = hidden_states.shape()[1];
+paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, hidden_states.place());
-extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, score_text.stream()>>>(
+extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, hidden_states.stream()>>>(
const_cast<int*>(max_seq_len.data<int>()),
const_cast<int*>(max_seq_len_index.data<int>()),
const_cast<int*>(mm_token_num_len.data<int>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(cu_seqlens_q.data<int>()),
-const_cast<float*>(score_text.data<float>()),
+const_cast<float*>(hidden_states.data<float>()),
output.data<float>(),
bsz,
hidden_size
@@ -76,9 +73,9 @@ std::vector<std::vector<int64_t>> ExtractTextTokenOutputInferShape(const std::ve
const std::vector<int64_t>& mm_token_num_len_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
-const std::vector<int64_t>& score_text_shape) {
+const std::vector<int64_t>& hidden_states_shape) {
const int bsz = seq_lens_this_time_shape[0];
-const int hidden_size = score_text_shape[1];
+const int hidden_size = hidden_states_shape[1];
return {{bsz, hidden_size}};
}
@@ -87,8 +84,8 @@ std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(const paddle::Dat
const paddle::DataType& mm_token_num_len_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
-const paddle::DataType& score_text_dtype) {
-return {score_text_dtype};
+const paddle::DataType& hidden_states_dtype) {
+return {hidden_states_dtype};
}
PD_BUILD_STATIC_OP(extract_text_token_output)
@@ -97,7 +94,7 @@ PD_BUILD_STATIC_OP(extract_text_token_output)
"mm_token_num_len",
"seq_lens_this_time",
"cu_seqlens_q",
-"score_text"})
+"hidden_states"})
.Outputs({"output"})
.SetKernelFn(PD_KERNEL(ExtractTextTokenOutput))
.SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape))