// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "fastdeploy/function/slice.h" #include "fastdeploy/function/eigen.h" #include namespace fastdeploy { namespace function { std::vector GetSliceDims(const std::vector& in_dims, const std::vector& axes, const std::vector& starts, const std::vector& ends, std::vector* steps = nullptr) { std::vector slice_dims(in_dims); for (size_t i = 0; i < axes.size(); ++i) { int64_t axis = axes[i]; if (in_dims[axis] == -1) { continue; } int64_t start = starts[i]; int64_t end = ends[i]; int64_t step = steps == nullptr ? 1 : (*steps)[i]; if (step > 0) { slice_dims[axis] = (end - start + step - 1) / step; } else { slice_dims[axis] = (end - start + step + 1) / step; } } return slice_dims; } void CheckAndUpdateSliceAttrs(const std::vector& in_dims, const std::vector& axes, std::vector* starts, std::vector* ends, std::vector* steps = nullptr) { for (size_t i = 0; i < axes.size(); ++i) { int64_t axis = axes[i]; FDASSERT(axis < in_dims.size(), "The axis value should be less than the rank of input, " "but received axes[%d] = %d, rank of input is %d.", i, axis, in_dims.size()); int64_t dim_value = in_dims[axis]; if (dim_value > 0) { int64_t step = steps == nullptr ? 1 : (*steps)[i]; FDASSERT(step != 0, "Step should not be 0, but received step = %d.", step); int64_t start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; start = (std::max)(start, static_cast(0)); int64_t end = 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; end = (std::min)(end, dim_value); if (step > 0) { start = (std::min)(start, dim_value); end = (std::max)(end, static_cast(0)); FDASSERT(end > start, "When step > 0, end should be greater than start, but " "received end = %d, start = %d.", end, start) } else { start = (std::min)(start, dim_value - 1); if (end < -1) { end += dim_value; } end = (std::max)(end, static_cast(-1)); FDASSERT(start >= end, "When step < 0, start should be greater than end, but " "received start = %d, end = %d.", start, end); } (*starts)[i] = start; (*ends)[i] = end; } else if (dim_value == 0) { (*starts)[i] = 0; (*ends)[i] = 0; } } } template void SliceKernel(const FDTensor& x, const std::vector& axes, const std::vector& starts, const std::vector& ends, FDTensor* out) { FDASSERT(starts.size() == axes.size(), "The size of starts must be equal to the size of axes."); FDASSERT(ends.size() == axes.size(), "The size of ends must be equal to the size of axes."); auto starts_idx = starts; auto end_idx = ends; auto in_dims = x.Shape(); CheckAndUpdateSliceAttrs(in_dims, axes, &starts_idx, &end_idx); auto slice_dims = GetSliceDims(in_dims, axes, starts, ends); auto offsets = Eigen::DSizes(); auto extents = Eigen::DSizes(); for (size_t i = 0; i < D; ++i) { offsets[i] = 0; extents[i] = slice_dims[i]; } for (size_t i = 0; i < axes.size(); ++i) { offsets[axes[i]] = starts[i]; } out->Allocate(slice_dims, x.Dtype()); auto in_t = EigenTensor::From(x, in_dims); auto out_t = EigenTensor::From(*out, slice_dims); const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice(); out_t.device(dev) = in_t.slice(offsets, extents); } void Slice(const FDTensor& x, const std::vector& axes, const std::vector& starts, const std::vector& ends, FDTensor* out) { FD_VISIT_ALL_TYPES( x.dtype, "SliceKernel", ([&] { int rank = x.Shape().size(); switch (rank) { case 1: SliceKernel(x, axes, starts, ends, out); break; case 2: SliceKernel(x, axes, starts, ends, out); break; case 3: SliceKernel(x, axes, starts, ends, out); break; case 4: SliceKernel(x, axes, starts, ends, out); break; case 5: SliceKernel(x, axes, starts, ends, out); break; case 6: SliceKernel(x, axes, starts, ends, out); break; default: FDASSERT(false, "The rank of input should be less than 7, but received %d.", rank); } })); } } // namespace function } // namespace fastdeploy