// Source: FastDeploy/custom_ops/gpu_ops/extract_text_token_output.cu (snapshot 2025-06-09 19:20:15 +08:00)

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
// Gathers one row of `score_text` per batch entry into `output`.
//
// For every batch row `row` in [0, bsz) and every column `col` in
// [0, hidden_size):
//   * if max_seq_len[0] == mm_token_num_len[0] and row == max_seq_len_index[0],
//     output[row, col] is set to 0;
//   * otherwise, if seq_lens_this_time[row] != 0, output[row, col] is copied
//     from score_text[src_row, col], where
//         src_row = cu_seqlens_q[row + 1] - 1
//     and mm_token_num_len[0] is additionally subtracted from src_row when
//     row >= max_seq_len_index[0];
//   * otherwise output[row, col] is left untouched (the caller's fill value
//     is preserved).
//
// Launch layout: any 1D grid / 1D block. The kernel walks the flattened
// [bsz, hidden_size] output with a grid-stride loop, so correctness does not
// depend on the launch configuration, and neighbouring threads touch
// neighbouring columns of the same row.
//
// No shared memory is used and no block-level barrier is required.
//
// Preconditions (not checked here):
//   * max_seq_len, max_seq_len_index and mm_token_num_len each hold at least
//     one int; seq_lens_this_time holds at least bsz ints and cu_seqlens_q at
//     least bsz + 1 ints.
//   * every computed src_row is a valid row of score_text.
//   * NOTE(review): src_row presumably addresses the last text token of the
//     sequence once multimodal tokens are removed - confirm with the caller.
//
// THREADBLOCK_SIZE is kept for interface compatibility; it is not used by
// the kernel body.
template <int THREADBLOCK_SIZE>
__global__ void extract_text_token_output_kernel(const int *max_seq_len,
                                                 const int *max_seq_len_index,
                                                 const int *mm_token_num_len,
                                                 const int *seq_lens_this_time,
                                                 const int *cu_seqlens_q,
                                                 const float *score_text,
                                                 float *output,
                                                 const int bsz,
                                                 const int hidden_size) {
    // Scalars shared by every element; each is a one-element device buffer.
    const int longest_len = max_seq_len[0];
    const int longest_idx = max_seq_len_index[0];
    const int mm_tokens = mm_token_num_len[0];

    // 64-bit arithmetic so bsz * hidden_size and row * hidden_size cannot
    // overflow a 32-bit int.
    const int64_t total = static_cast<int64_t>(bsz) * hidden_size;
    const int64_t step = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t flat = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         flat < total;
         flat += step) {
        const int row = static_cast<int>(flat / hidden_size);
        const int col = static_cast<int>(flat % hidden_size);

        // The designated row is zeroed when its length equals the
        // multimodal token count.
        if (longest_len == mm_tokens && row == longest_idx) {
            output[flat] = 0.0f;
            continue;
        }

        // Rows with no tokens this step keep whatever the caller wrote.
        if (seq_lens_this_time[row] == 0) {
            continue;
        }

        int src_row = cu_seqlens_q[row + 1] - 1;
        if (row >= longest_idx) {
            src_row -= mm_tokens;
        }
        output[flat] = score_text[static_cast<int64_t>(src_row) * hidden_size + col];
    }
}
// Host entry point of the `extract_text_token_output` custom op.
//
// Allocates a float32 tensor of shape [bsz, hidden_size] on the same place as
// `score_text`, pre-filled with 1, and launches
// extract_text_token_output_kernel on `score_text`'s stream to populate it.
//
//   bsz         = seq_lens_this_time.shape()[0]
//   hidden_size = score_text.shape()[1]
//
// The launch is asynchronous; no synchronization is performed here.
// When either dimension is empty the pre-filled (empty) tensor is returned
// without launching, because a grid dimension of 0 is an invalid launch
// configuration.
//
// Returns a single-element vector holding the output tensor.
std::vector<paddle::Tensor> ExtractTextTokenOutput(
    const paddle::Tensor& max_seq_len,
    const paddle::Tensor& max_seq_len_index,
    const paddle::Tensor& mm_token_num_len,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& score_text) {
    const int bsz = seq_lens_this_time.shape()[0];
    const int hidden_size = score_text.shape()[1];

    // Entries the kernel does not write keep this fill value of 1.
    paddle::Tensor output = paddle::full(
        {bsz, hidden_size}, 1, paddle::DataType::FLOAT32, score_text.place());

    // Nothing to compute, and <<<0, ...>>> would be rejected by the runtime.
    if (bsz <= 0 || hidden_size <= 0) {
        return {output};
    }

    constexpr int kThreadsPerBlock = 1024;
    // Grid of hidden_size blocks, kThreadsPerBlock threads each.
    extract_text_token_output_kernel<kThreadsPerBlock>
        <<<hidden_size, kThreadsPerBlock, 0, score_text.stream()>>>(
            const_cast<int*>(max_seq_len.data<int>()),
            const_cast<int*>(max_seq_len_index.data<int>()),
            const_cast<int*>(mm_token_num_len.data<int>()),
            const_cast<int*>(seq_lens_this_time.data<int>()),
            const_cast<int*>(cu_seqlens_q.data<int>()),
            const_cast<float*>(score_text.data<float>()),
            output.data<float>(),
            bsz,
            hidden_size);
    return {output};
}
// Shape inference for `extract_text_token_output`.
//
// The single output has shape [bsz, hidden_size], where
//   bsz         = seq_lens_this_time_shape[0]
//   hidden_size = score_text_shape[1]
// The remaining input shapes are not consulted.
//
// Dimensions are kept as int64_t end to end so that large (or symbolic
// negative) extents are forwarded unchanged instead of being narrowed to int.
//
// Precondition (not checked): seq_lens_this_time_shape has at least one
// dimension and score_text_shape has at least two.
std::vector<std::vector<int64_t>> ExtractTextTokenOutputInferShape(
    const std::vector<int64_t>& max_seq_len_shape,
    const std::vector<int64_t>& max_seq_len_index_shape,
    const std::vector<int64_t>& mm_token_num_len_shape,
    const std::vector<int64_t>& seq_lens_this_time_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& score_text_shape) {
    const int64_t bsz = seq_lens_this_time_shape[0];
    const int64_t hidden_size = score_text_shape[1];
    return {std::vector<int64_t>{bsz, hidden_size}};
}
// Dtype inference for `extract_text_token_output`.
//
// Reports one output whose dtype equals that of `score_text`; the dtypes of
// the other five inputs are ignored.
// NOTE(review): ExtractTextTokenOutput allocates its result as FLOAT32
// regardless of the input dtype, so the two agree only for float32
// `score_text` - confirm that is the only supported dtype.
std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(
    const paddle::DataType& max_seq_len_dtype,
    const paddle::DataType& max_seq_len_index_dtype,
    const paddle::DataType& mm_token_num_len_dtype,
    const paddle::DataType& seq_lens_this_time_dtype,
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& score_text_dtype) {
    std::vector<paddle::DataType> out_dtypes;
    out_dtypes.push_back(score_text_dtype);
    return out_dtypes;
}
// Registers the custom op under the name `extract_text_token_output`.
// Six inputs are declared, in the same order as the parameters of
// ExtractTextTokenOutput, and one output named "output".
// The kernel function plus the shape- and dtype-inference functions defined
// in this file are attached to the op.
PD_BUILD_STATIC_OP(extract_text_token_output)
.Inputs({"max_seq_len",
"max_seq_len_index",
"mm_token_num_len",
"seq_lens_this_time",
"cu_seqlens_q",
"score_text"})
.Outputs({"output"})
.SetKernelFn(PD_KERNEL(ExtractTextTokenOutput))
.SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ExtractTextTokenOutputInferDtype));