// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "fastdeploy/function/sort.h" #include "fastdeploy/function/eigen.h" #include "fastdeploy/function/transpose.h" #include #include #include namespace fastdeploy { namespace function { template static void FullSort(Type input_height, Type input_width, int input_dim, const FDTensor* input, FDTensor* out, FDTensor* indices, bool descending) { out->Allocate(input->Shape(), input->Dtype()); indices->Allocate(input->Shape(), TypeToDataType::dtype); T* t_out = reinterpret_cast(out->Data()); Type* t_indices = reinterpret_cast(indices->Data()); for (Type i = 0; i < input_height; ++i) { std::vector> col_vec; col_vec.reserve(input_width); if (input_dim == 1) { auto e_input = EigenVector::Flatten(*input); for (Type j = 0; j < input_width; ++j) { col_vec.push_back(std::pair(e_input(j), j)); } } else { auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); for (Type j = 0; j < input_width; ++j) { col_vec.push_back(std::pair(e_input(i, j), j)); } } std::sort(col_vec.begin(), col_vec.end(), [&](const std::pair& l, const std::pair& r) { if (descending) return (std::isnan(static_cast(l.first)) && !std::isnan(static_cast(r.first))) || (l.first > r.first); else return (!std::isnan(static_cast(l.first)) && std::isnan(static_cast(r.first))) || (l.first < r.first); }); for (Type j = 0; j < input_width; ++j) { t_out[i * input_width + j] = col_vec[j].first; t_indices[i * input_width + j] = col_vec[j].second; } } } template void SortKernel(const FDTensor& x, FDTensor* out, FDTensor* indices, FDDataType indices_type, bool descending, int axis) { auto input_shape = x.Shape(); int rank = input_shape.size(); axis = (axis < 0) ? (rank + axis) : axis; // Do full sort if (axis == -1 || axis + 1 == rank) { const int64_t input_width = input_shape[rank - 1]; const int64_t input_height = x.Numel() / input_width; FD_VISIT_INT_TYPES(indices_type, "FullSort", ([&] { FullSort(input_height, input_width, rank, &x, out, indices, descending); })); } else { // If not full sort do transpose std::vector trans; for (int i = 0; i < axis; i++) { trans.push_back(i); } trans.push_back(rank - 1); for (int i = axis + 1; i < rank - 1; i++) { trans.push_back(i); } trans.push_back(axis); FDTensor trans_inp; Transpose(x, &trans_inp, trans); const int64_t input_width = input_shape[axis]; const int64_t input_height = x.Numel() / input_width; FD_VISIT_INT_TYPES(indices_type, "FullSort", ([&] { FullSort(input_height, input_width, rank, &trans_inp, out, indices, descending); })); // transpose back Transpose(*out, out, trans); Transpose(*indices, indices, trans); } } void Sort(const FDTensor& x, FDTensor* out, FDTensor* indices, int axis, bool descending, FDDataType indices_type) { FD_VISIT_INT_FLOAT_TYPES(x.dtype, "SortKernel", ([&] { SortKernel(x, out, indices, indices_type, descending, axis); })); } } // namespace function } // namespace fastdeploy