diff --git a/fastdeploy/function/functions.h b/fastdeploy/function/functions.h
index f789a17ce..d2ffe6a0c 100644
--- a/fastdeploy/function/functions.h
+++ b/fastdeploy/function/functions.h
@@ -17,7 +17,6 @@
 #include "fastdeploy/function/cast.h"
 #include "fastdeploy/function/clip.h"
 #include "fastdeploy/function/concat.h"
-#include "fastdeploy/function/cuda_cast.h"
 #include "fastdeploy/function/cumprod.h"
 #include "fastdeploy/function/elementwise.h"
 #include "fastdeploy/function/full.h"
@@ -28,6 +27,7 @@
 #include "fastdeploy/function/pad.h"
 #include "fastdeploy/function/quantile.h"
 #include "fastdeploy/function/reduce.h"
+#include "fastdeploy/function/slice.h"
 #include "fastdeploy/function/softmax.h"
 #include "fastdeploy/function/sort.h"
 #include "fastdeploy/function/split.h"
diff --git a/fastdeploy/function/math.cc b/fastdeploy/function/math.cc
index 0b3446947..751b79083 100644
--- a/fastdeploy/function/math.cc
+++ b/fastdeploy/function/math.cc
@@ -28,11 +28,13 @@ namespace function {
 template <typename T, typename Functor>
 void ActivationImpl(const FDTensor& X, FDTensor* Out, const Functor& functor) {
   FDASSERT(Out != nullptr, "Output Out should not be nullptr");
+  // Compute into a temporary so the kernel is safe when Out aliases X
+  // (in-place call); move the result into Out at the end.
+  FDTensor out_tmp;
   auto x = EigenVector<T>::Flatten(X);
-  Out->Allocate(X.Shape(), X.Dtype());
-  auto out = EigenVector<T>::Flatten(*Out);
+  out_tmp.Allocate(X.Shape(), X.Dtype());
+  auto out = EigenVector<T>::Flatten(out_tmp);
   const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
   functor(dev, x, out);
+  *Out = std::move(out_tmp);
 }
 
 DEFINE_ACTIVATION_KERNEL(Sqrt, SqrtFunctor)
diff --git a/fastdeploy/function/slice.cc b/fastdeploy/function/slice.cc
new file mode 100644
index 000000000..8b50bf422
--- /dev/null
+++ b/fastdeploy/function/slice.cc
@@ -0,0 +1,167 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/function/slice.h"
+#include "fastdeploy/function/eigen.h"
+
+#include <algorithm>
+
+namespace fastdeploy {
+namespace function {
+
+// Computes the output shape of a slice from (already normalized) start/end
+// indices and optional per-axis steps. A dim of -1 (dynamic) is left as-is.
+std::vector<int64_t> GetSliceDims(const std::vector<int64_t>& in_dims,
+                                  const std::vector<int64_t>& axes,
+                                  const std::vector<int64_t>& starts,
+                                  const std::vector<int64_t>& ends,
+                                  std::vector<int64_t>* steps = nullptr) {
+  std::vector<int64_t> slice_dims(in_dims);
+
+  for (size_t i = 0; i < axes.size(); ++i) {
+    int64_t axis = axes[i];
+    if (in_dims[axis] == -1) {
+      continue;
+    }
+
+    int64_t start = starts[i];
+    int64_t end = ends[i];
+    int64_t step = steps == nullptr ? 1 : (*steps)[i];
+
+    if (step > 0) {
+      // Ceil division: a partial final stride still yields one element.
+      slice_dims[axis] = (end - start + step - 1) / step;
+    } else {
+      slice_dims[axis] = (end - start + step + 1) / step;
+    }
+  }
+  return slice_dims;
+}
+
+// Normalizes negative / out-of-range starts and ends in place, so that for a
+// positive step 0 <= start <= end <= dim holds before slicing.
+void CheckAndUpdateSliceAttrs(const std::vector<int64_t>& in_dims,
+                              const std::vector<int64_t>& axes,
+                              std::vector<int64_t>* starts,
+                              std::vector<int64_t>* ends,
+                              std::vector<int64_t>* steps = nullptr) {
+  for (size_t i = 0; i < axes.size(); ++i) {
+    int64_t axis = axes[i];
+    FDASSERT(axis < static_cast<int64_t>(in_dims.size()),
+             "The axis value should be less than the rank of input, "
+             "but received axes[%d] = %d, rank of input is %d.",
+             i, axis, in_dims.size());
+    int64_t dim_value = in_dims[axis];
+
+    if (dim_value > 0) {
+      int64_t step = steps == nullptr ? 1 : (*steps)[i];
+      FDASSERT(step != 0, "Step should not be 0, but received step = %d.",
+               step);
+      // Negative start counts from the end of the axis.
+      int64_t start =
+          (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
+      start = std::max(start, static_cast<int64_t>(0));
+
+      int64_t end =
+          0 < step && (*ends)[i] < 0 ?
((*ends)[i] + dim_value) : (*ends)[i];
+      end = std::min(end, dim_value);
+
+      if (step > 0) {
+        start = std::min(start, dim_value);
+        end = std::max(end, static_cast<int64_t>(0));
+        FDASSERT(end > start,
+                 "When step > 0, end should be greater than start, but "
+                 "received end = %d, start = %d.",
+                 end, start);
+      } else {
+        start = std::min(start, dim_value - 1);
+        if (end < -1) {
+          end += dim_value;
+        }
+        // -1 marks "one before the first element" for reverse slicing.
+        end = std::max(end, static_cast<int64_t>(-1));
+        FDASSERT(start >= end,
+                 "When step < 0, start should be greater than end, but "
+                 "received start = %d, end = %d.",
+                 start, end);
+      }
+
+      (*starts)[i] = start;
+      (*ends)[i] = end;
+    } else if (dim_value == 0) {
+      (*starts)[i] = 0;
+      (*ends)[i] = 0;
+    }
+  }
+}
+
+// Slices x along `axes` into *out using Eigen's tensor slice; T is the
+// element type and D the (compile-time) rank of the input.
+template <typename T, size_t D>
+void SliceKernel(const FDTensor& x, const std::vector<int64_t>& axes,
+                 const std::vector<int64_t>& starts,
+                 const std::vector<int64_t>& ends, FDTensor* out) {
+  FDASSERT(starts.size() == axes.size(),
+           "The size of starts must be equal to the size of axes.");
+  FDASSERT(ends.size() == axes.size(),
+           "The size of ends must be equal to the size of axes.");
+  auto starts_idx = starts;
+  auto end_idx = ends;
+  auto in_dims = x.Shape();
+  CheckAndUpdateSliceAttrs(in_dims, axes, &starts_idx, &end_idx);
+  // Use the normalized indices from here on: the raw starts/ends may be
+  // negative or out of range, which would yield wrong slice dims and
+  // negative Eigen offsets (undefined behavior).
+  auto slice_dims = GetSliceDims(in_dims, axes, starts_idx, end_idx);
+
+  auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
+  auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
+  for (size_t i = 0; i < D; ++i) {
+    offsets[i] = 0;
+    extents[i] = slice_dims[i];
+  }
+  for (size_t i = 0; i < axes.size(); ++i) {
+    offsets[axes[i]] = starts_idx[i];
+  }
+
+  out->Allocate(slice_dims, x.Dtype());
+  auto in_t = EigenTensor<T, D>::From(x, in_dims);
+  auto out_t = EigenTensor<T, D>::From(*out, slice_dims);
+  const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
+  out_t.device(dev) = in_t.slice(offsets, extents);
+}
+
+void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
+           const std::vector<int64_t>& starts,
+           const std::vector<int64_t>& ends, FDTensor* out) {
+  FD_VISIT_ALL_TYPES(
+      x.dtype, "SliceKernel", ([&] {
+        // Dispatch on runtime rank to the compile-time-ranked kernel.
+        int rank = x.Shape().size();
+        switch (rank) {
+          case 1:
+            SliceKernel<data_t, 1>(x, axes, starts, ends, out);
+            break;
+          case 2:
+            SliceKernel<data_t, 2>(x, axes, starts, ends, out);
+            break;
+          case 3:
+            SliceKernel<data_t, 3>(x, axes, starts, ends, out);
+            break;
+          case 4:
+            SliceKernel<data_t, 4>(x, axes, starts, ends, out);
+            break;
+          case 5:
+            SliceKernel<data_t, 5>(x, axes, starts, ends, out);
+            break;
+          case 6:
+            SliceKernel<data_t, 6>(x, axes, starts, ends, out);
+            break;
+          default:
+            FDASSERT(
+                false,
+                "The rank of input should be less than 7, but received %d.",
+                rank);
+        }
+      }));
+}
+
+}  // namespace function
+}  // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/function/slice.h b/fastdeploy/function/slice.h
new file mode 100644
index 000000000..d676a232e
--- /dev/null
+++ b/fastdeploy/function/slice.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fastdeploy/core/fd_tensor.h"
+
+namespace fastdeploy {
+namespace function {
+
+/** This operator produces a slice of input along multiple axes.
+    @param x The input tensor.
+    @param axes Axes that starts and ends apply to.
+    @param starts If starts is a list or tuple, the elements of it should be
+      integers or Tensors with shape [1]. If starts is a Tensor, it should
+      be a 1-D Tensor. It represents starting indices of the corresponding
+      axis in axes.
+    @param ends If ends is a list or tuple, the elements of it should be
+      integers or Tensors with shape [1].
If ends is a Tensor, it should
+      be a 1-D Tensor. It represents ending indices of the corresponding
+      axis in axes.
+    @param out The output tensor which stores the result.
+*/
+
+FASTDEPLOY_DECL void Slice(const FDTensor& x, const std::vector<int64_t>& axes,
+                           const std::vector<int64_t>& starts,
+                           const std::vector<int64_t>& ends, FDTensor* out);
+
+}  // namespace function
+}  // namespace fastdeploy
diff --git a/tests/function/test_slice.cc b/tests/function/test_slice.cc
new file mode 100644
index 000000000..03a7fdfbd
--- /dev/null
+++ b/tests/function/test_slice.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/core/fd_tensor.h"
+#include "fastdeploy/function/slice.h"
+#include "glog/logging.h"
+#include "gtest_utils.h"
+#include "gtest/gtest.h"
+// NOTE(review): the bracketed include targets were lost in extraction; the
+// test body only needs std::vector (plus array for parity with sibling tests).
+#include <array>
+#include <vector>
+
+namespace fastdeploy {
+namespace function {
+
+// Builds the FP32 input used by the slice test; logical shape is [2, 3, 4].
+std::vector<float> CreateTestData() {
+  // Shape: [2, 3, 4]
+  std::vector<float> x_data = {
+      1.8428625,  0.6461913, 0.13740455,  0.11430702, 0.659926,  0.535816,
+      0.7429162,  0.8456049, -1.21228176, 0.29970083, 0.8621713, 0.40894133,
+      0.12684688, 2.1566195, -9.42884097, 20.8476526, 0.2458633, 0.669046,
+      0.87888306, 0.6762589, 0.666453,    0.32523027, 0.4139388, 0.8341406};
+  return x_data;
+}
+
+TEST(fastdeploy, slice) {
+  CheckShape check_shape;
+  CheckData check_data;
+  FDTensor x, y;
+  auto test_data = CreateTestData();
+  // x borrows test_data's buffer; test_data must outlive x.
+  x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
+
+  // x[0:1]
+  Slice(x, {0}, {0}, {1}, &y);
+  std::vector<float> result = {1.842862,  0.646191, 0.137405, 0.114307,
+                               0.659926,  0.535816, 0.742916, 0.845605,
+                               -1.212282, 0.299701, 0.862171, 0.408941};
+  check_shape(y.shape, {1, 3, 4});
+  check_data(reinterpret_cast<float*>(y.Data()), result.data(),
+             result.size());
+
+  // x[:, 1:2]
+  Slice(x, {1}, {1}, {2}, &y);
+  result = {0.659926, 0.535816, 0.742916, 0.845605,
+            0.245863, 0.669046, 0.878883, 0.676259};
+  check_shape(y.shape, {2, 1, 4});
+  check_data(reinterpret_cast<float*>(y.Data()), result.data(),
+             result.size());
+
+  // x[:, 0:1, 2:4]
+  Slice(x, {1, 2}, {0, 2}, {1, 4}, &y);
+  result = {0.137405, 0.114307, -9.428841, 20.847652};
+  check_shape(y.shape, {2, 1, 2});
+  check_data(reinterpret_cast<float*>(y.Data()), result.data(),
+             result.size());
+}
+
+}  // namespace function
+}  // namespace fastdeploy