// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "paddle/extension.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif template void RebuildPaddingCPUImpl(T *output_data, const T *input_data, const int *cu_seqlens_q_data, const int *seq_len_this_time_data, const int *seq_lens_decoder_data, const int *seq_lens_encoder_data, int max_input_length, int dim_embed, const int elem_nums) { for (int i = 0; i < elem_nums; ++i) { const int bi = i / dim_embed; const int bias_idx = i % dim_embed; int seq_id = 0; if (seq_len_this_time_data[bi] == 0) { continue; } if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) { continue; } if (seq_lens_encoder_data[bi] > 0) { seq_id = seq_lens_encoder_data[bi] - 1; } const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id; const int src_offset = ori_token_idx * dim_embed + bias_idx; output_data[i] = input_data[src_offset]; } } template void RebuildAppendPaddingCPUImpl(T *output_data, const T *input_data, const int *cu_seqlens_q_data, const int *seq_len_this_time_data, const int *seq_lens_decoder_data, const int *seq_lens_encoder_data, const int *output_padding_offset_data, const int max_input_length, const int dim_embed, const int64_t output_elem_nums) { for (int i = 0; i < output_elem_nums; ++i) { int out_token_id = i / dim_embed; int ori_token_id = out_token_id + output_padding_offset_data[out_token_id]; int bi = ori_token_id / max_input_length; if (seq_len_this_time_data[bi] == 0 || (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0)) { continue; } int seq_id = 0; if (seq_lens_encoder_data[bi] > 0) { seq_id = seq_lens_encoder_data[bi] - 1; } int input_token_id = cu_seqlens_q_data[bi] + seq_id; int bias_idx = i % dim_embed; int src_offset = input_token_id * dim_embed + bias_idx; output_data[i] = input_data[src_offset]; } } std::vector RebuildPaddingCPU( const paddle::Tensor &tmp_out, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &seq_len_this_time, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, int max_input_length) { auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true); auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true); auto seq_len_this_time_cpu = seq_len_this_time.copy_to(paddle::CPUPlace(), true); auto seq_lens_decoder_cpu = seq_lens_decoder.copy_to(paddle::CPUPlace(), true); auto seq_lens_encoder_cpu = seq_lens_encoder.copy_to(paddle::CPUPlace(), true); paddle::optional output_padding_offset_cpu; if (output_padding_offset) { output_padding_offset_cpu = output_padding_offset->copy_to(paddle::CPUPlace(), true); } int token_num = tmp_out_cpu.shape()[0]; int dim_embed = tmp_out_cpu.shape()[1]; int bsz = cu_seqlens_q_cpu.shape()[0] - 1; paddle::Tensor out; if (output_padding_offset_cpu) { int need_delete_token_num = 0; for (int i = 0; i < bsz; ++i) { if (seq_lens_encoder_cpu.data()[i] > 0) { need_delete_token_num += seq_lens_encoder_cpu.data()[i] - 1; } } int output_token_num = token_num - need_delete_token_num; out = paddle::full({output_token_num, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace()); } else { out = paddle::full( {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace()); } const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data(); const int *seq_len_this_time_data = seq_len_this_time_cpu.data(); const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data(); const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data(); int elem_nums = out.numel(); if (output_padding_offset_cpu) { const int *output_padding_offset_data = output_padding_offset_cpu->data(); switch (tmp_out_cpu.dtype()) { case paddle::DataType::FLOAT32: RebuildAppendPaddingCPUImpl(out.data(), tmp_out_cpu.data(), cu_seqlens_q_data, seq_len_this_time_data, seq_lens_decoder_data, seq_lens_encoder_data, output_padding_offset_data, max_input_length, dim_embed, elem_nums); break; case paddle::DataType::FLOAT16: RebuildAppendPaddingCPUImpl( out.data(), tmp_out_cpu.data(), cu_seqlens_q_data, seq_len_this_time_data, seq_lens_decoder_data, seq_lens_encoder_data, output_padding_offset_data, max_input_length, dim_embed, elem_nums); break; case paddle::DataType::BFLOAT16: RebuildAppendPaddingCPUImpl( out.data(), tmp_out_cpu.data(), cu_seqlens_q_data, seq_len_this_time_data, seq_lens_decoder_data, seq_lens_encoder_data, output_padding_offset_data, max_input_length, dim_embed, elem_nums); break; default: PD_THROW( "Unsupported data type for rebuild_padding_cpu. " "Only float32, float16, and bfloat16 are supported."); } } else { switch (tmp_out_cpu.dtype()) { case paddle::DataType::FLOAT32: RebuildPaddingCPUImpl(out.data(), tmp_out_cpu.data(), cu_seqlens_q_data, seq_len_this_time_data, seq_lens_decoder_data, seq_lens_encoder_data, max_input_length, dim_embed, elem_nums); break; case paddle::DataType::FLOAT16: RebuildPaddingCPUImpl( out.data(), tmp_out_cpu.data(), cu_seqlens_q_data, seq_len_this_time_data, seq_lens_decoder_data, seq_lens_encoder_data, max_input_length, dim_embed, elem_nums); break; case paddle::DataType::BFLOAT16: RebuildPaddingCPUImpl( out.data(), tmp_out_cpu.data(), cu_seqlens_q_data, seq_len_this_time_data, seq_lens_decoder_data, seq_lens_encoder_data, max_input_length, dim_embed, elem_nums); break; default: PD_THROW( "Unsupported data type for rebuild_padding_cpu. " "Only float32, float16, and bfloat16 are supported."); } } return {out}; } std::vector> RebuildPaddingInferShape( const std::vector &tmp_out_shape, const std::vector &cu_seqlens_q_shape, const std::vector &seq_len_this_time_shape, const std::vector &seq_lens_decoder_shape, const std::vector &seq_lens_encoder_shape, const paddle::optional> &output_padding_offset_shape) { int64_t dim_embed = tmp_out_shape[1]; if (output_padding_offset_shape) { return {{-1, dim_embed}}; } else { int64_t bsz = cu_seqlens_q_shape[0] - 1; return {{bsz, dim_embed}}; } } std::vector RebuildPaddingInferDtype( const paddle::DataType &tmp_out_dtype, const paddle::DataType &cu_seqlens_q_dtype, const paddle::DataType &seq_len_this_time_dtype, const paddle::DataType &seq_lens_decoder_dtype, const paddle::DataType &seq_lens_encoder_dtype, const paddle::optional &output_padding_offset_dtype) { return {tmp_out_dtype}; } PD_BUILD_STATIC_OP(rebuild_padding_cpu) .Inputs({"tmp_out", "cu_seqlens_q", "seq_len_this_time", "seq_lens_decoder", "seq_lens_encoder", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) .Attrs({"max_input_length: int"}) .SetKernelFn(PD_KERNEL(RebuildPaddingCPU)) .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));