// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#include <hiprand.h>
#include <hiprand_kernel.h>
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#else
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif

// Maps a paddle::DataType enum value to the device-side type (DataType) and
// the host-side type (data_t) used by the extension API.
template <paddle::DataType D>
class PDTraits;

template <>
class PDTraits<paddle::DataType::FLOAT32> {
 public:
  typedef float DataType;
  typedef float data_t;
};

template <>
class PDTraits<paddle::DataType::FLOAT16> {
 public:
  typedef half DataType;
  typedef paddle::float16 data_t;
};

template <>
class PDTraits<paddle::DataType::BFLOAT16> {
 public:
#ifdef PADDLE_WITH_HIP
  typedef hip_bfloat16 DataType;
#else
  typedef __nv_bfloat16 DataType;
#endif
  typedef paddle::bfloat16 data_t;
};

// For every batch element bi, copies the `length`-wide row of logits selected
// by ids[bi] (an id of -1 is treated as 0) into logits_out[bi, :].
template <typename T>
__global__ void get_value_by_id(const T *logits,
                                const int *ids,
                                T *logits_out,
                                int bs,
                                int seq_len,
                                int length) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
  int idx = bid * blockDim.x + tid;
  for (int i = idx; i < bs * length; i += gridDim.x * blockDim.x) {
    int bi = i / length;
    int lane = i % length;
    int si = ids[bi];
    if (si == -1) {
      si = 0;
    }
    const T *logits_now = logits + bi * seq_len * length + si * length;
    T *logits_out_now = logits_out + bi * length;
    logits_out_now[lane] = logits_now[lane];
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> gather_idx(const paddle::Tensor &logits,
                                       const paddle::Tensor &gather_id) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  PD_CHECK(gather_id.dtype() == paddle::DataType::INT32);

  auto cu_stream = logits.stream();
  std::vector<int64_t> logits_shape = logits.shape();
  std::vector<int64_t> id_shape = gather_id.shape();
  int logits_bs = logits_shape[0];
  int seq_len = logits_shape[1];
  int logits_len = logits_shape[2];
  auto logits_out =
      paddle::empty({logits_bs, logits_len}, logits.type(), logits.place());
  int id_bs = id_shape[0];

  int64_t numels = logits_bs * logits_len;
  int block_size = 128;
  int grid_size = (numels + block_size - 1) / block_size;
  get_value_by_id<DataType_><<<grid_size, block_size, 0, cu_stream>>>(
      reinterpret_cast<DataType_ *>(
          const_cast<data_t *>(logits.data<data_t>())),
      gather_id.data<int>(),
      reinterpret_cast<DataType_ *>(
          const_cast<data_t *>(logits_out.data<data_t>())),
      logits_bs,
      seq_len,
      logits_len);
  return {logits_out};
}

// Dispatches to the typed kernel wrapper based on the logits dtype.
std::vector<paddle::Tensor> GatherIdx(const paddle::Tensor &logits,
                                      const paddle::Tensor &gather_id) {
  switch (logits.type()) {
    case paddle::DataType::BFLOAT16: {
      return gather_idx<paddle::DataType::BFLOAT16>(logits, gather_id);
    }
    case paddle::DataType::FLOAT16: {
      return gather_idx<paddle::DataType::FLOAT16>(logits, gather_id);
    }
    case paddle::DataType::FLOAT32: {
      return gather_idx<paddle::DataType::FLOAT32>(logits, gather_id);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only bfloat16, float16 and float32 are supported.");
      break;
    }
  }
}

std::vector<std::vector<int64_t>> GatherIdxInferShape(
    const std::vector<int64_t> &logits_shape,
    const std::vector<int64_t> &gather_id_shape) {
  std::vector<int64_t> out_shape = {logits_shape[0], logits_shape[2]};
  return {out_shape};
}

std::vector<paddle::DataType> GatherIdxInferDtype(
    const paddle::DataType &logits_dtype,
    const paddle::DataType &gather_id_dtype) {
  return {logits_dtype};
}

PD_BUILD_STATIC_OP(gather_idx)
    .Inputs({"logits", "gather_id"})
    .Outputs({"logits_out"})
    .SetKernelFn(PD_KERNEL(GatherIdx))
    .SetInferShapeFn(PD_INFER_SHAPE(GatherIdxInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GatherIdxInferDtype));