[Functions] Add quantile function (#700)

* Add sort function

* Add isfinite function

* upgrade isinf isnan

* Add Scalar to FDTensor

* Add floor, ceil function

* add cast functions

* Update out_tmp

* Update quantile

* add gather scatter along axis

* finish quantile function

* Add quantile unittest

* refresh code style for test source code

* Add comments

* Add full function

* Add scalar to fd tensor

* Add full unittest

* Add functions headers

* move fdtensor operators to fastdeploy namespace
This commit is contained in:
Jack Zhou
2022-11-28 09:51:40 +08:00
committed by GitHub
parent 4e74ac06fb
commit 129dda7809
37 changed files with 1567 additions and 75 deletions

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/cast.h"
#include <algorithm>
namespace fastdeploy {
namespace function {
template <typename InT, typename OutT> struct CastOpTransformFunctor {
OutT operator()(InT in) const { return static_cast<OutT>(in); }
};
template <typename InT>
void CastKernel(const FDTensor& x, FDTensor* out, FDDataType output_dtype) {
FD_VISIT_ALL_TYPES(output_dtype, "CastOpTransformFunctor", ([&] {
auto* in_begin = reinterpret_cast<const InT*>(x.Data());
auto* in_end = in_begin + x.Numel();
FDTensor out_tmp;
out_tmp.Allocate(x.Shape(), output_dtype);
auto* out_begin = reinterpret_cast<data_t*>(out_tmp.Data());
std::transform(in_begin, in_end, out_begin,
CastOpTransformFunctor<InT, data_t>());
*out = std::move(out_tmp);
}));
}
void Cast(const FDTensor& x, FDTensor* out, FDDataType output_dtype) {
FD_VISIT_ALL_TYPES(x.dtype, "CastKernel",
([&] { CastKernel<data_t>(x, out, output_dtype); }));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Cast x to output data type element-wise. Only for float type FDTensor
@param x The input tensor.
@param out The output tensor which stores the result.
@param output_dtype The type of output tensor.
*/
FASTDEPLOY_DECL void Cast(const FDTensor& x, FDTensor* out,
FDDataType output_dtype);
} // namespace function
} // namespace fastdeploy

View File

@@ -22,7 +22,7 @@ namespace function {
/** Excute the concatenate operation for input FDTensor along given axis.
@param x The input tensor.
@param out The output tensor which stores the result.
@param axisi Axis which will be concatenated.
@param axis Axis which will be concatenated.
*/
FASTDEPLOY_DECL void Concat(const std::vector<FDTensor>& x, FDTensor* out,

View File

@@ -32,45 +32,21 @@ void Add(const FDTensor& x, const FDTensor& y, FDTensor* out) {
([&] { AddRawKernel<data_t>()(x, y, -1, out); }));
}
FDTensor operator+(const FDTensor& x, const FDTensor& y) {
FDTensor out;
Add(x, y, &out);
return out;
}
void Subtract(const FDTensor& x, const FDTensor& y, FDTensor* out) {
FD_VISIT_ALL_TYPES(x.dtype, "SubtractRawKernel",
([&] { SubtractRawKernel<data_t>()(x, y, -1, out); }));
}
FDTensor operator-(const FDTensor& x, const FDTensor& y) {
FDTensor out;
Subtract(x, y, &out);
return out;
}
void Multiply(const FDTensor& x, const FDTensor& y, FDTensor* out) {
FD_VISIT_ALL_TYPES(x.dtype, "MultiplyRawKernel",
([&] { MultiplyRawKernel<data_t>()(x, y, -1, out); }));
}
FDTensor operator*(const FDTensor& x, const FDTensor& y) {
FDTensor out;
Multiply(x, y, &out);
return out;
}
void Divide(const FDTensor& x, const FDTensor& y, FDTensor* out) {
FD_VISIT_ALL_TYPES(x.dtype, "DivideRawKernel",
([&] { DivideRawKernel<data_t>()(x, y, -1, out); }));
}
FDTensor operator/(const FDTensor& x, const FDTensor& y) {
FDTensor out;
Divide(x, y, &out);
return out;
}
template <typename T> struct MaximumRawKernel {
void operator()(const FDTensor& x, const FDTensor& y, int axis,
FDTensor* out) {
@@ -85,4 +61,29 @@ void Maximum(const FDTensor& x, const FDTensor& y, FDTensor* out) {
}
} // namespace function
FDTensor operator+(const FDTensor& x, const FDTensor& y) {
FDTensor out;
function::Add(x, y, &out);
return out;
}
FDTensor operator-(const FDTensor& x, const FDTensor& y) {
FDTensor out;
function::Subtract(x, y, &out);
return out;
}
FDTensor operator*(const FDTensor& x, const FDTensor& y) {
FDTensor out;
function::Multiply(x, y, &out);
return out;
}
FDTensor operator/(const FDTensor& x, const FDTensor& y) {
FDTensor out;
function::Divide(x, y, &out);
return out;
}
} // namespace fastdeploy

View File

@@ -26,8 +26,6 @@ namespace function {
*/
FASTDEPLOY_DECL void Add(const FDTensor& x, const FDTensor& y, FDTensor* out);
FASTDEPLOY_DECL FDTensor operator+(const FDTensor& x, const FDTensor& y);
/** Excute the subtract operation for input FDTensors. *out = x - y.
@param x The input tensor.
@param y The input tensor.
@@ -36,8 +34,6 @@ FASTDEPLOY_DECL FDTensor operator+(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL void Subtract(const FDTensor& x, const FDTensor& y,
FDTensor* out);
FASTDEPLOY_DECL FDTensor operator-(const FDTensor& x, const FDTensor& y);
/** Excute the multiply operation for input FDTensors. *out = x * y.
@param x The input tensor.
@param y The input tensor.
@@ -46,7 +42,6 @@ FASTDEPLOY_DECL FDTensor operator-(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL void Multiply(const FDTensor& x, const FDTensor& y,
FDTensor* out);
FASTDEPLOY_DECL FDTensor operator*(const FDTensor& x, const FDTensor& y);
/** Excute the divide operation for input FDTensors. *out = x / y.
@param x The input tensor.
@param y The input tensor.
@@ -54,7 +49,6 @@ FASTDEPLOY_DECL FDTensor operator*(const FDTensor& x, const FDTensor& y);
*/
FASTDEPLOY_DECL void Divide(const FDTensor& x, const FDTensor& y,
FDTensor* out);
FASTDEPLOY_DECL FDTensor operator/(const FDTensor& x, const FDTensor& y);
/** Excute the maximum operation for input FDTensors. *out = max(x, y).
@param x The input tensor.
@@ -65,4 +59,13 @@ FASTDEPLOY_DECL void Maximum(const FDTensor& x, const FDTensor& y,
FDTensor* out);
} // namespace function
FASTDEPLOY_DECL FDTensor operator+(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL FDTensor operator-(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL FDTensor operator*(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL FDTensor operator/(const FDTensor& x, const FDTensor& y);
} // namespace fastdeploy

View File

@@ -0,0 +1,42 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/full.h"
#include "fastdeploy/function/eigen.h"
#include <algorithm>
namespace fastdeploy {
namespace function {
template <typename T> void FullValue(FDTensor* tensor, const Scalar& val) {
auto t = EigenVector<T>::Flatten(*tensor);
auto& place = *EigenDeviceWrapper::GetInstance()->GetDevice();
t.device(place) = t.constant(val.to<T>());
}
void Full(const Scalar& value, const std::vector<int64_t>& shape, FDTensor* out,
FDDataType dtype) {
FD_VISIT_ALL_TYPES(dtype, "Full", ([&] {
out->Allocate(shape, dtype);
FullValue<data_t>(out, value);
}));
}
void FullLike(const FDTensor& x, const Scalar& value, FDTensor* out,
FDDataType dtype) {
Full(value, x.Shape(), out, dtype);
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_scalar.h"
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Fill the value to tensor
@param value The value to be filled in tensor
@param shape The shape of output tensor.
@param out The output tensor which stores the result.
@param dtype The data type of output tensor. Default to float32
*/
FASTDEPLOY_DECL void Full(const Scalar& value,
const std::vector<int64_t>& shape, FDTensor* out,
FDDataType dtype = FDDataType::FP32);
/** Fill the value to tensor
@param x The input tensor.
@param value The value to be filled in tensor
@param out The output tensor which stores the result.
@param dtype The data type of output tensor. Default to float32
*/
FASTDEPLOY_DECL void FullLike(const FDTensor& x, const Scalar& value,
FDTensor* out,
FDDataType dtype = FDDataType::FP32);
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/function/cast.h"
#include "fastdeploy/function/clip.h"
#include "fastdeploy/function/concat.h"
#include "fastdeploy/function/cuda_cast.h"
#include "fastdeploy/function/cumprod.h"
#include "fastdeploy/function/elementwise.h"
#include "fastdeploy/function/full.h"
#include "fastdeploy/function/gather_scatter_along_axis.h"
#include "fastdeploy/function/isfinite.h"
#include "fastdeploy/function/linspace.h"
#include "fastdeploy/function/math.h"
#include "fastdeploy/function/pad.h"
#include "fastdeploy/function/quantile.h"
#include "fastdeploy/function/reduce.h"
#include "fastdeploy/function/softmax.h"
#include "fastdeploy/function/sort.h"
#include "fastdeploy/function/split.h"
#include "fastdeploy/function/tile.h"
#include "fastdeploy/function/transpose.h"

View File

@@ -0,0 +1,125 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/gather_scatter_along_axis.h"
#include "fastdeploy/function/tile.h"
namespace fastdeploy {
namespace function {
class TensorAssign {
public:
template <typename tensor_t>
void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data = *src_data;
}
};
static TensorAssign tensor_assign;
template <typename T, typename index_t = int64_t, bool is_scatter_like = true>
struct GatherScatterFunctor {
template <typename func_t>
void operator()(const FDTensor& x, int axis, const FDTensor& index,
FDTensor* result, const func_t& reduce_op) {
if (index.Numel() == 0) {
return;
}
result->Allocate(index.Shape(), x.Dtype());
const T* x_data = reinterpret_cast<const T*>(x.Data());
const index_t* index_data = reinterpret_cast<const index_t*>(index.Data());
T* result_data = reinterpret_cast<T*>(result->Data());
int64_t x_size = x.Numel();
int64_t index_size = index.Numel();
int64_t result_size = result->Numel();
auto x_dims = x.Shape();
auto index_dims = index.Shape();
auto result_dims = result->Shape();
if (x_size == 0 || result_size == 0 || index_size == 0) {
FDASSERT(false, "zero size input found, self_size, result_size, "
"index_size cannot be 0");
return;
}
int select_dim_size = index_dims[axis];
// index matrix has different shape with self matrix or src matrix.
int replaced_select_dim_size =
is_scatter_like ? result_dims[axis] : x_dims[axis];
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
for (int64_t i = 0; i < axis; ++i) {
inner_dim_size *= index_dims[i];
}
for (int i = axis + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}
int64_t index_idx = 0;
int64_t self_idx, src_idx;
// N layer loop squeezed into 3 layers loop
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < select_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = index_data[index_idx];
// This index might out of bound of index matrix's index, so here
// multiply the replaced_select_dim_size.
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * replaced_select_dim_size;
self_idx = is_scatter_like ? replace_index : index_idx;
src_idx = is_scatter_like ? index_idx : replace_index;
reduce_op((T*)(result_data + self_idx), // NOLINT
(T*)(x_data + src_idx)); // NOLINT
index_idx++;
}
}
}
}
};
template <typename T> struct GatherFunctor {
void operator()(const FDTensor& x, int axis, const FDTensor& index,
FDTensor* result) {
FD_VISIT_INT_TYPES(index.Dtype(), "GatherFunctor", [&]() {
auto x_shape = x.Shape();
auto index_shape = index.Shape();
std::vector<int64_t> repeat_times(x_shape.size(), 1);
for (int i = 0; i < x_shape.size(); ++i) {
repeat_times[i] = x_shape[i] / index_shape[i];
}
repeat_times[axis] = 1;
FDTensor gs_index;
Tile(index, repeat_times, &gs_index);
GatherScatterFunctor<T, data_t, /*is_scatter_like=*/false>()(
x, axis, gs_index, result, tensor_assign);
});
}
};
void GatherAlongAxis(const FDTensor& x, const FDTensor& index, FDTensor* result,
int axis) {
int rank = x.Shape().size();
FDASSERT(axis >= -rank && axis < rank,
"axis should be in range [-%d, %d - 1].", rank, rank - 1);
if (axis < 0) {
axis += rank;
}
FD_VISIT_ALL_TYPES(x.Dtype(), "GatherAlongAxis", [&]() {
GatherFunctor<data_t>()(x, axis, index, result);
});
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Output is obtained by gathering entries of axis of x indexed by index and
* concatenate them together.
@param x The input tensor.
@param index The index of a tensor to gather.
@param out The output tensor which stores the result.
@param axis Axis which will be gathered.
*/
void GatherAlongAxis(const FDTensor& x, const FDTensor& index, FDTensor* result,
int axis);
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,111 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/isfinite.h"
#include "fastdeploy/core/float16.h"
#include <algorithm>
#include <type_traits>
namespace fastdeploy {
namespace function {
template <typename T, typename OutT, class Enable = void> struct IsNanFunctor {
OutT operator()(const T& a) const { return static_cast<OutT>(std::isnan(a)); }
};
template <typename T, typename OutT>
struct IsNanFunctor<T, OutT,
typename std::enable_if<std::is_integral<T>::value>::type> {
OutT operator()(const T& a) const { return static_cast<OutT>(false); }
};
template <typename OutT> struct IsNanFunctor<fastdeploy::float16, OutT, void> {
OutT operator()(const fastdeploy::float16& a) const {
return static_cast<OutT>(fastdeploy::isnan(a));
}
};
template <typename T, typename OutT, class Enable = void> struct IsInfFunctor {
OutT operator()(const T& a) const { return static_cast<OutT>(std::isinf(a)); }
};
template <typename T, typename OutT>
struct IsInfFunctor<T, OutT,
typename std::enable_if<std::is_integral<T>::value>::type> {
OutT operator()(const T& a) const { return static_cast<OutT>(false); }
};
template <typename OutT> struct IsInfFunctor<fastdeploy::float16, OutT, void> {
OutT operator()(const fastdeploy::float16& a) const {
return static_cast<OutT>(fastdeploy::isinf(a));
}
};
template <typename T, typename OutT, class Enable = void>
struct IsFiniteFunctor {
OutT operator()(const T& a) const {
return static_cast<OutT>(std::isfinite(a));
}
};
template <typename T, typename OutT>
struct IsFiniteFunctor<
T, OutT, typename std::enable_if<std::is_integral<T>::value>::type> {
OutT operator()(const T& a) const { return static_cast<OutT>(true); }
};
template <typename OutT>
struct IsFiniteFunctor<fastdeploy::float16, OutT, void> {
OutT operator()(const fastdeploy::float16& a) const {
return static_cast<OutT>(fastdeploy::isfinite(a));
}
};
#define DEFINE_ISFINITE_KERNEL(isfinite_kernel, functor) \
template <typename T> \
void isfinite_kernel(const FDTensor& x, FDTensor* out, FDDataType dtype) { \
FD_VISIT_ALL_TYPES(dtype, #isfinite_kernel, ([&] { \
out->Allocate(x.Shape(), dtype); \
functor<T, data_t> unary_func; \
data_t* out_ptr = \
reinterpret_cast<data_t*>(out->Data()); \
const T* input_ptr = \
reinterpret_cast<const T*>(x.Data()); \
std::transform(input_ptr, input_ptr + x.Numel(), \
out_ptr, unary_func); \
})); \
}
DEFINE_ISFINITE_KERNEL(IsNanKernel, IsNanFunctor)
DEFINE_ISFINITE_KERNEL(IsInfKernel, IsInfFunctor)
DEFINE_ISFINITE_KERNEL(IsFiniteKernel, IsFiniteFunctor)
#undef DEFINE_ISFINITE_KERNEL
void IsNan(const FDTensor& x, FDTensor* out, FDDataType dtype) {
FD_VISIT_FLOAT_TYPES(x.dtype, "IsNanKernel",
([&] { IsNanKernel<data_t>(x, out, dtype); }));
}
void IsInf(const FDTensor& x, FDTensor* out, FDDataType dtype) {
FD_VISIT_FLOAT_TYPES(x.dtype, "IsInfKernel",
([&] { IsInfKernel<data_t>(x, out, dtype); }));
}
void IsFinite(const FDTensor& x, FDTensor* out, FDDataType dtype) {
FD_VISIT_FLOAT_TYPES(x.dtype, "IsFiniteKernel",
([&] { IsFiniteKernel<data_t>(x, out, dtype); }));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Return whether every element of input tensor is NaN or not.
@param x The input tensor.
@param out The output tensor which stores the result.
@param dtype The output data type
*/
FASTDEPLOY_DECL void IsNan(const FDTensor& x, FDTensor* out,
FDDataType dtype = FDDataType::BOOL);
/** Return whether every element of input tensor is Inf or not.
@param x The input tensor.
@param out The output tensor which stores the result.
@param dtype The output data type
*/
FASTDEPLOY_DECL void IsInf(const FDTensor& x, FDTensor* out,
FDDataType dtype = FDDataType::BOOL);
/** Return whether every element of input tensor is finite or not.
@param x The input tensor.
@param out The output tensor which stores the result.
@param dtype The output data type
*/
FASTDEPLOY_DECL void IsFinite(const FDTensor& x, FDTensor* out,
FDDataType dtype = FDDataType::BOOL);
} // namespace function
} // namespace fastdeploy

View File

@@ -40,6 +40,8 @@ DEFINE_ACTIVATION_KERNEL(Log, LogFunctor)
DEFINE_ACTIVATION_KERNEL(Round, RoundFunctor)
DEFINE_ACTIVATION_KERNEL(Exp, ExpFunctor)
DEFINE_ACTIVATION_KERNEL(Abs, AbsFunctor)
DEFINE_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_ACTIVATION_KERNEL(Floor, FloorFunctor)
void Sqrt(const FDTensor& x, FDTensor* out) {
FD_VISIT_FLOAT_TYPES(x.dtype, "SqrtKernel",
@@ -66,5 +68,15 @@ void Abs(const FDTensor& x, FDTensor* out) {
([&] { AbsKernel<data_t>(x, out); }));
}
void Ceil(const FDTensor& x, FDTensor* out) {
FD_VISIT_FLOAT_TYPES(x.dtype, "CeilKernel",
([&] { CeilKernel<data_t>(x, out); }));
}
void Floor(const FDTensor& x, FDTensor* out) {
FD_VISIT_FLOAT_TYPES(x.dtype, "FloorKernel",
([&] { FloorKernel<data_t>(x, out); }));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -49,5 +49,17 @@ FASTDEPLOY_DECL void Exp(const FDTensor& x, FDTensor* out);
*/
FASTDEPLOY_DECL void Abs(const FDTensor& x, FDTensor* out);
/** Computes ceil of x element-wise. Only for float type FDTensor
@param x The input tensor.
@param out The output tensor which stores the result.
*/
FASTDEPLOY_DECL void Ceil(const FDTensor& x, FDTensor* out);
/** Computes floor of x element-wise. Only for float type FDTensor
@param x The input tensor.
@param out The output tensor which stores the result.
*/
FASTDEPLOY_DECL void Floor(const FDTensor& x, FDTensor* out);
} // namespace function
} // namespace fastdeploy

View File

@@ -61,5 +61,21 @@ template <typename T> struct AbsFunctor {
}
};
// ceil(x) = ceiling(x)
template <typename T> struct CeilFunctor {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.ceil();
}
};
// floor(x) = flooring(x)
template <typename T> struct FloorFunctor {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.floor();
}
};
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,130 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/quantile.h"
#include "fastdeploy/core/fd_scalar.h"
#include "fastdeploy/function/cast.h"
#include "fastdeploy/function/concat.h"
#include "fastdeploy/function/elementwise.h"
#include "fastdeploy/function/gather_scatter_along_axis.h"
#include "fastdeploy/function/isfinite.h"
#include "fastdeploy/function/math.h"
#include "fastdeploy/function/reduce.h"
#include "fastdeploy/function/sort.h"
#include "fastdeploy/function/transpose.h"
#include <algorithm>
#include <cmath>
#include <numeric>
namespace fastdeploy {
namespace function {
template <typename T>
void QuantileKernel(const FDTensor& x, const std::vector<double>& q,
const std::vector<int>& axis, FDTensor* out) {
FDASSERT(q.size() > 0, "q should not be empty.");
FDASSERT(axis.size() > 0, "axis should not be empty.");
std::vector<int64_t> axis_src;
std::vector<int64_t> out_shape = x.Shape();
int64_t rank = x.Shape().size();
for (auto axis_single : axis) {
FDASSERT(axis_single >= -rank && axis_single < rank,
"The axis is expected to be in range of [%d, %d), but got %d",
-rank, rank, axis_single);
if (axis_single < 0) {
axis_single += rank;
}
axis_src.push_back(axis_single);
out_shape[axis_single] = 1;
}
std::vector<int64_t> axis_dst;
for (int64_t i = 0; i < rank; ++i) {
if (std::find(axis_src.begin(), axis_src.end(), i) == axis_src.end()) {
axis_dst.push_back(i);
}
}
axis_dst.insert(axis_dst.end(), axis_src.begin(), axis_src.end());
FDTensor y;
Transpose(x, &y, axis_dst);
std::vector<int64_t> y_shape(rank - axis_src.size(), 0);
y_shape.push_back(-1);
y.Reshape({y_shape});
int64_t target_axis = rank - 1;
FDTensor mask, valid_counts, mask_any;
IsNan(y, &mask);
Any(mask, &mask_any, {target_axis}, true);
bool* mask_data = reinterpret_cast<bool*>(mask.Data());
std::transform(mask_data, mask_data + mask.Numel(), mask_data,
[](const bool& val) { return !val; });
Cast(mask_any, &mask_any, FDDataType::FP64);
Cast(mask, &mask, FDDataType::FP64);
Sum(mask, &valid_counts, {target_axis}, true);
FDTensor one_tensor(Scalar(static_cast<double>(1.0)));
std::vector<FDTensor> indices;
FDTensor last_index(Scalar(static_cast<double>(x.Shape()[target_axis])));
for (auto q_num : q) {
FDASSERT(q_num >= 0 && q_num <= 1, "q should be in range [0, 1]");
FDTensor q_tensor(static_cast<double>(q_num));
FDTensor index = q_tensor * (valid_counts - one_tensor);
index = mask_any * last_index + (one_tensor - mask_any) * index;
indices.push_back(index);
}
std::vector<FDTensor> outputs;
FDTensor sorted_tensor, sorted_indices_tensor;
Sort(y, &sorted_tensor, &sorted_indices_tensor, target_axis);
Cast(sorted_tensor, &sorted_tensor, FDDataType::FP64);
FDTensor indices_below, indices_upper;
for (auto&& index : indices) {
Floor(index, &indices_below);
Ceil(index, &indices_upper);
Cast(indices_below, &indices_below, FDDataType::INT32);
Cast(indices_upper, &indices_upper, FDDataType::INT32);
FDTensor tensor_below, tensor_upper;
GatherAlongAxis(sorted_tensor, indices_below, &tensor_below, target_axis);
GatherAlongAxis(sorted_tensor, indices_upper, &tensor_upper, target_axis);
// Need to cast to FP64 to compute with index and tensor_upper
Cast(indices_below, &indices_below, FDDataType::FP64);
FDTensor weight = index - indices_below;
FDTensor out = tensor_below + weight * (tensor_upper - tensor_below);
out.Squeeze(target_axis);
if (out.Dtype() != x.Dtype()) {
Cast(out, &out, x.Dtype());
}
outputs.push_back(std::move(out));
}
if (outputs.size() > 1) {
// Execute stack operation
for (auto& output : outputs) {
output.ExpandDim(0);
}
Concat(outputs, out, 0);
} else {
*out = std::move(outputs[0]);
}
}
void Quantile(const FDTensor& x, const std::vector<double>& q,
const std::vector<int>& axis, FDTensor* out) {
FD_VISIT_FLOAT_TYPES(x.dtype, "QuantileKernel",
([&] { QuantileKernel<data_t>(x, q, axis, out); }));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,34 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Compute the quantile of the input along the specified axis. If any values
** in a reduced row are NaN, then the quantiles for that reduction will be NaN.
@param x The input tensor.
@param q The q for calculate quantile, which should be in range [0, 1].
@param axis The axis along which to calculate quantile. axis should be int
or list of int.
@param out The output tensor which stores the result.
*/
FASTDEPLOY_DECL void Quantile(const FDTensor& x, const std::vector<double>& q,
const std::vector<int>& axis, FDTensor* out);
} // namespace function
} // namespace fastdeploy

118
fastdeploy/function/sort.cc Normal file
View File

@@ -0,0 +1,118 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/sort.h"
#include "fastdeploy/function/eigen.h"
#include "fastdeploy/function/transpose.h"
#include <algorithm>
#include <cmath>
#include <numeric>
namespace fastdeploy {
namespace function {
template <typename T, typename Type>
static void FullSort(Type input_height, Type input_width, int input_dim,
const FDTensor* input, FDTensor* out, FDTensor* indices,
bool descending) {
out->Allocate(input->Shape(), input->Dtype());
indices->Allocate(input->Shape(), TypeToDataType<Type>::dtype);
T* t_out = reinterpret_cast<T*>(out->Data());
Type* t_indices = reinterpret_cast<Type*>(indices->Data());
for (Type i = 0; i < input_height; ++i) {
std::vector<std::pair<T, Type>> col_vec;
col_vec.reserve(input_width);
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
for (Type j = 0; j < input_width; ++j) {
col_vec.push_back(std::pair<T, Type>(e_input(j), j));
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
col_vec.push_back(std::pair<T, Type>(e_input(i, j), j));
}
}
std::sort(col_vec.begin(), col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + j] = col_vec[j].first;
t_indices[i * input_width + j] = col_vec[j].second;
}
}
}
template <typename T>
void SortKernel(const FDTensor& x, FDTensor* out, FDTensor* indices,
FDDataType indices_type, bool descending, int axis) {
auto input_shape = x.Shape();
int rank = input_shape.size();
axis = (axis < 0) ? (rank + axis) : axis;
// Do full sort
if (axis == -1 || axis + 1 == rank) {
const int64_t input_width = input_shape[rank - 1];
const int64_t input_height = x.Numel() / input_width;
FD_VISIT_INT_TYPES(indices_type, "FullSort", ([&] {
FullSort<T, data_t>(input_height, input_width, rank,
&x, out, indices, descending);
}));
} else {
// If not full sort do transpose
std::vector<int64_t> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(rank - 1);
for (int i = axis + 1; i < rank - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
FDTensor trans_inp;
Transpose(x, &trans_inp, trans);
const int64_t input_width = input_shape[axis];
const int64_t input_height = x.Numel() / input_width;
FD_VISIT_INT_TYPES(indices_type, "FullSort", ([&] {
FullSort<T, data_t>(input_height, input_width, rank,
&trans_inp, out, indices,
descending);
}));
// transpose back
Transpose(*out, out, trans);
Transpose(*indices, indices, trans);
}
}
void Sort(const FDTensor& x, FDTensor* out, FDTensor* indices, int axis,
bool descending, FDDataType indices_type) {
FD_VISIT_INT_FLOAT_TYPES(x.dtype, "SortKernel", ([&] {
SortKernel<data_t>(x, out, indices, indices_type,
descending, axis);
}));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/**
* @brief Performs sorting on the input tensor along the given axis and outputs
* two tensors, Output(Out) and Output(Indices). They reserve the same
* shape with Input(X), and Output(Out) represents the sorted tensor
* while Output(Indices) gives the sorted order along the given axis
* Attr(axis).
* @param x The input of sort
* @param out The sorted tensor of sort op, with the same shape as
* x
* @param indices The indices of a tensor giving the sorted order, with
* the same shape as x
* @param axis The axis along which to sort the tensor.
* When axis < 0, the actual axis will be the |axis|'th
* counting backwards
* @param descending The descending attribute is a flag to tell
* algorithm how to sort the input data.
* If descending is true, will sort by descending order,
* else if false, sort by ascending order
* @param indices_type The data type of indices, default to int64
*/
FASTDEPLOY_DECL void Sort(const FDTensor& x, FDTensor* out, FDTensor* indices,
int axis = 0, bool descending = false,
FDDataType indices_type = FDDataType::INT64);
} // namespace function
} // namespace fastdeploy