// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>

#include "helper.h"

// Fills `output` (bsz x hidden_size, row-major) one batch row at a time.
//
// For every batch row `row`:
//   * if max_seq_len[0] == mm_token_num_len[0] and row == max_seq_len_index[0],
//     the whole output row is set to 0;
//   * otherwise, if seq_lens_this_time[row] != 0, the output row is a copy of
//     hidden_states row (cu_seqlens_q[row + 1] - 1)
//     (presumably the last token of that sequence -- confirm with the caller);
//   * otherwise the output row is left untouched, i.e. it keeps whatever the
//     host initialised it with.
//
// Launch layout: 1-D grid, 1-D block, no shared memory. Blocks stride over
// batch rows and threads stride over the hidden dimension, so the result is
// correct for any grid/block size and adjacent threads touch adjacent floats.
// THREADBLOCK_SIZE is the maximum block size the kernel is compiled for.
//
// Preconditions (not checked here): the three scalar inputs hold at least one
// int, seq_lens_this_time holds at least bsz ints, cu_seqlens_q holds at least
// bsz + 1 ints, and `output` does not alias any input.
template <int THREADBLOCK_SIZE>
__global__ void __launch_bounds__(THREADBLOCK_SIZE)
extract_text_token_output_kernel(const int *__restrict__ max_seq_len,
                                 const int *__restrict__ max_seq_len_index,
                                 const int *__restrict__ mm_token_num_len,
                                 const int *__restrict__ seq_lens_this_time,
                                 const int *__restrict__ cu_seqlens_q,
                                 const float *__restrict__ hidden_states,
                                 float *__restrict__ output,
                                 const int bsz,
                                 const int hidden_size) {
  // Both values are identical for every thread, so read them once.
  const bool zero_selected_row = (max_seq_len[0] == mm_token_num_len[0]);
  const int selected_row = max_seq_len_index[0];

  for (int row = blockIdx.x; row < bsz; row += gridDim.x) {
    // 64-bit offsets: row * hidden_size can exceed INT_MAX for large inputs.
    float *dst = output + static_cast<int64_t>(row) * hidden_size;

    if (zero_selected_row && row == selected_row) {
      for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
        dst[col] = 0.0f;
      }
      continue;
    }

    // Rows with no tokens this step keep their initial value.
    if (seq_lens_this_time[row] == 0) {
      continue;
    }

    const int src_row = cu_seqlens_q[row + 1] - 1;
    const float *src =
        hidden_states + static_cast<int64_t>(src_row) * hidden_size;
    for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
      dst[col] = src[col];
    }
  }
  // No barrier is needed: threads never exchange data with each other.
}

// Host entry point of the `extract_text_token_output` op.
//
// Allocates a FLOAT32 tensor of shape {bsz, hidden_size} filled with 1 on the
// same place as `hidden_states`, then runs the kernel above over it, where
//   bsz         = seq_lens_this_time.shape()[0]
//   hidden_size = hidden_states.shape()[1]
//
// The integer inputs are read as int32 and `hidden_states` as float32.
// Returns a one-element vector holding the output tensor.
std::vector<paddle::Tensor> ExtractTextTokenOutput(
    const paddle::Tensor& max_seq_len,
    const paddle::Tensor& max_seq_len_index,
    const paddle::Tensor& mm_token_num_len,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& hidden_states) {
  constexpr int kThreadsPerBlock = 1024;

  const int bsz = seq_lens_this_time.shape()[0];
  const int hidden_size = hidden_states.shape()[1];

  paddle::Tensor output = paddle::full({bsz, hidden_size},
                                       1,
                                       paddle::DataType::FLOAT32,
                                       hidden_states.place());

  // A zero-sized grid is an invalid launch configuration, and there is
  // nothing to write anyway.
  if (bsz <= 0 || hidden_size <= 0) {
    return {output};
  }

  // One block per batch row; the kernel's strided loops cover any hidden_size.
  // NOTE(review): launched on the stream of hidden_states -- confirm this
  // matches the stream the surrounding ops use.
  extract_text_token_output_kernel<kThreadsPerBlock>
      <<<bsz, kThreadsPerBlock, 0, hidden_states.stream()>>>(
          max_seq_len.data<int>(),
          max_seq_len_index.data<int>(),
          mm_token_num_len.data<int>(),
          seq_lens_this_time.data<int>(),
          cu_seqlens_q.data<int>(),
          hidden_states.data<float>(),
          output.data<float>(),
          bsz,
          hidden_size);

  return {output};
}

// Shape inference: the single output is
// {seq_lens_this_time_shape[0], hidden_states_shape[1]}.
std::vector<std::vector<int64_t>> ExtractTextTokenOutputInferShape(
    const std::vector<int64_t>& max_seq_len_shape,
    const std::vector<int64_t>& max_seq_len_index_shape,
    const std::vector<int64_t>& mm_token_num_len_shape,
    const std::vector<int64_t>& seq_lens_this_time_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& hidden_states_shape) {
  // Keep the dimensions as int64_t so large shapes are not truncated.
  const int64_t bsz = seq_lens_this_time_shape[0];
  const int64_t hidden_size = hidden_states_shape[1];
  return {{bsz, hidden_size}};
}

// Dtype inference: reports the dtype of hidden_states for the output.
// The kernel function always produces FLOAT32, so the two agree only when
// hidden_states is float32.
std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(
    const paddle::DataType& max_seq_len_dtype,
    const paddle::DataType& max_seq_len_index_dtype,
    const paddle::DataType& mm_token_num_len_dtype,
    const paddle::DataType& seq_lens_this_time_dtype,
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& hidden_states_dtype) {
  return {hidden_states_dtype};
}

PD_BUILD_STATIC_OP(extract_text_token_output)
    .Inputs({"max_seq_len",
             "max_seq_len_index",
             "mm_token_num_len",
             "seq_lens_this_time",
             "cu_seqlens_q",
             "hidden_states"})
    .Outputs({"output"})
    .SetKernelFn(PD_KERNEL(ExtractTextTokenOutput))
    .SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ExtractTextTokenOutputInferDtype));