[Functions] Add quantile function (#700)

* Add sort function

* Add isfinite function

* upgrade isinf isnan

* Add Scalar to FDTensor

* Add floor, ceil function

* add cast functions

* Update out_tmp

* Update quantile

* add gather scatter along axis

* finish quantile function

* Add quantile unittest

* refresh code style for test source code

* Add comments

* Add full function

* Add scalar to fd tensor

* Add full unittest

* Add functions headers

* move fdtensor operators to fastdeploy namespace
This commit is contained in:
Jack Zhou
2022-11-28 09:51:40 +08:00
committed by GitHub
parent 4e74ac06fb
commit 129dda7809
37 changed files with 1567 additions and 75 deletions

121
fastdeploy/core/fd_scalar.h Normal file
View File

@@ -0,0 +1,121 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <limits>
#include "fastdeploy/core/fd_type.h"
#include "fastdeploy/core/float16.h"
namespace fastdeploy {
class Scalar {
public:
// Constructor support implicit
Scalar() : Scalar(0) {}
Scalar(double val) : dtype_(FDDataType::FP64) { // NOLINT
data_.f64 = val;
}
Scalar(float val) : dtype_(FDDataType::FP32) { // NOLINT
data_.f32 = val;
}
Scalar(float16 val) : dtype_(FDDataType::FP16) { // NOLINT
data_.f16 = val;
}
Scalar(int64_t val) : dtype_(FDDataType::INT64) { // NOLINT
data_.i64 = val;
}
Scalar(int32_t val) : dtype_(FDDataType::INT32) { // NOLINT
data_.i32 = val;
}
Scalar(int16_t val) : dtype_(FDDataType::INT16) { // NOLINT
data_.i16 = val;
}
Scalar(int8_t val) : dtype_(FDDataType::INT8) { // NOLINT
data_.i8 = val;
}
Scalar(uint8_t val) : dtype_(FDDataType::UINT8) { // NOLINT
data_.ui8 = val;
}
Scalar(bool val) : dtype_(FDDataType::BOOL) { // NOLINT
data_.b = val;
}
// The compatible method for fliud operators,
// and it will be removed in the future.
explicit Scalar(const std::string& str_value) : dtype_(FDDataType::FP64) {
if (str_value == "inf") {
data_.f64 = std::numeric_limits<double>::infinity();
} else if (str_value == "-inf") {
data_.f64 = -std::numeric_limits<double>::infinity();
} else if (str_value == "nan") {
data_.f64 = std::numeric_limits<double>::quiet_NaN();
} else {
data_.f64 = std::stod(str_value);
}
}
template <typename RT> inline RT to() const {
switch (dtype_) {
case FDDataType::FP32:
return static_cast<RT>(data_.f32);
case FDDataType::FP64:
return static_cast<RT>(data_.f64);
case FDDataType::FP16:
return static_cast<RT>(data_.f16);
case FDDataType::INT32:
return static_cast<RT>(data_.i32);
case FDDataType::INT64:
return static_cast<RT>(data_.i64);
case FDDataType::INT16:
return static_cast<RT>(data_.i16);
case FDDataType::INT8:
return static_cast<RT>(data_.i8);
case FDDataType::UINT8:
return static_cast<RT>(data_.ui8);
case FDDataType::BOOL:
return static_cast<RT>(data_.b);
default:
FDASSERT(false, "Invalid enum scalar data type `%s`.",
Str(dtype_).c_str());
}
}
FDDataType dtype() const { return dtype_; }
private:
FDDataType dtype_;
union data {
bool b;
int8_t i8;
int16_t i16;
int32_t i32;
int64_t i64;
uint8_t ui8;
float16 f16;
float f32;
double f64;
} data_;
};
} // namespace fastdeploy

View File

@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/fd_scalar.h"
#include "fastdeploy/core/float16.h"
#include "fastdeploy/utils/utils.h"
#include <algorithm>
#include <cstring>
#ifdef WITH_GPU
@@ -344,6 +346,40 @@ void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes,
}
FDTensor::FDTensor(const std::string& tensor_name) { name = tensor_name; }
FDTensor::FDTensor(const char* tensor_name) { name = tensor_name; }
FDTensor::FDTensor(const Scalar& scalar) {
Allocate({1}, scalar.dtype());
switch (scalar.dtype()) {
case FDDataType::BOOL:
(reinterpret_cast<bool*>(Data()))[0] = scalar.to<bool>();
break;
case FDDataType::UINT8:
(reinterpret_cast<uint8_t*>(Data()))[0] = scalar.to<uint8_t>();
break;
case FDDataType::INT8:
(reinterpret_cast<int8_t*>(Data()))[0] = scalar.to<int8_t>();
break;
case FDDataType::INT16:
(reinterpret_cast<int16_t*>(Data()))[0] = scalar.to<int16_t>();
break;
case FDDataType::INT32:
(reinterpret_cast<int*>(Data()))[0] = scalar.to<int>();
break;
case FDDataType::INT64:
(reinterpret_cast<int64_t*>(Data()))[0] = scalar.to<int64_t>();
break;
case FDDataType::FP16:
(reinterpret_cast<float16*>(Data()))[0] = scalar.to<float16>();
break;
case FDDataType::FP32:
(reinterpret_cast<float*>(Data()))[0] = scalar.to<float>();
break;
case FDDataType::FP64:
(reinterpret_cast<double*>(Data()))[0] = scalar.to<double>();
break;
}
}
FDTensor::FDTensor(const FDTensor& other)
: shape(other.shape), name(other.name), dtype(other.dtype),

View File

@@ -23,6 +23,8 @@
namespace fastdeploy {
struct Scalar;
struct FASTDEPLOY_DECL FDTensor {
// std::vector<int8_t> data;
void* buffer_ = nullptr;
@@ -126,6 +128,8 @@ struct FASTDEPLOY_DECL FDTensor {
FDTensor() {}
explicit FDTensor(const std::string& tensor_name);
explicit FDTensor(const char* tensor_name);
// Deep copy
FDTensor(const FDTensor& other);
// Move constructor
@@ -136,6 +140,9 @@ struct FASTDEPLOY_DECL FDTensor {
// Move assignment
FDTensor& operator=(FDTensor&& other);
// Scalar to FDTensor
explicit FDTensor(const Scalar& scalar);
~FDTensor() { FreeFn(); }
static void CopyBuffer(void* dst, const void* src, size_t nbytes,

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/cast.h"
#include <algorithm>
namespace fastdeploy {
namespace function {
template <typename InT, typename OutT> struct CastOpTransformFunctor {
OutT operator()(InT in) const { return static_cast<OutT>(in); }
};
template <typename InT>
void CastKernel(const FDTensor& x, FDTensor* out, FDDataType output_dtype) {
FD_VISIT_ALL_TYPES(output_dtype, "CastOpTransformFunctor", ([&] {
auto* in_begin = reinterpret_cast<const InT*>(x.Data());
auto* in_end = in_begin + x.Numel();
FDTensor out_tmp;
out_tmp.Allocate(x.Shape(), output_dtype);
auto* out_begin = reinterpret_cast<data_t*>(out_tmp.Data());
std::transform(in_begin, in_end, out_begin,
CastOpTransformFunctor<InT, data_t>());
*out = std::move(out_tmp);
}));
}
void Cast(const FDTensor& x, FDTensor* out, FDDataType output_dtype) {
FD_VISIT_ALL_TYPES(x.dtype, "CastKernel",
([&] { CastKernel<data_t>(x, out, output_dtype); }));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Cast x to output data type element-wise. Only for float type FDTensor
@param x The input tensor.
@param out The output tensor which stores the result.
@param output_dtype The type of output tensor.
*/
FASTDEPLOY_DECL void Cast(const FDTensor& x, FDTensor* out,
FDDataType output_dtype);
} // namespace function
} // namespace fastdeploy

View File

@@ -22,7 +22,7 @@ namespace function {
/** Excute the concatenate operation for input FDTensor along given axis.
@param x The input tensor.
@param out The output tensor which stores the result.
@param axisi Axis which will be concatenated.
@param axis Axis which will be concatenated.
*/
FASTDEPLOY_DECL void Concat(const std::vector<FDTensor>& x, FDTensor* out,

View File

@@ -32,45 +32,21 @@ void Add(const FDTensor& x, const FDTensor& y, FDTensor* out) {
([&] { AddRawKernel<data_t>()(x, y, -1, out); }));
}
FDTensor operator+(const FDTensor& x, const FDTensor& y) {
FDTensor out;
Add(x, y, &out);
return out;
}
void Subtract(const FDTensor& x, const FDTensor& y, FDTensor* out) {
FD_VISIT_ALL_TYPES(x.dtype, "SubtractRawKernel",
([&] { SubtractRawKernel<data_t>()(x, y, -1, out); }));
}
FDTensor operator-(const FDTensor& x, const FDTensor& y) {
FDTensor out;
Subtract(x, y, &out);
return out;
}
void Multiply(const FDTensor& x, const FDTensor& y, FDTensor* out) {
FD_VISIT_ALL_TYPES(x.dtype, "MultiplyRawKernel",
([&] { MultiplyRawKernel<data_t>()(x, y, -1, out); }));
}
FDTensor operator*(const FDTensor& x, const FDTensor& y) {
FDTensor out;
Multiply(x, y, &out);
return out;
}
void Divide(const FDTensor& x, const FDTensor& y, FDTensor* out) {
FD_VISIT_ALL_TYPES(x.dtype, "DivideRawKernel",
([&] { DivideRawKernel<data_t>()(x, y, -1, out); }));
}
FDTensor operator/(const FDTensor& x, const FDTensor& y) {
FDTensor out;
Divide(x, y, &out);
return out;
}
template <typename T> struct MaximumRawKernel {
void operator()(const FDTensor& x, const FDTensor& y, int axis,
FDTensor* out) {
@@ -85,4 +61,29 @@ void Maximum(const FDTensor& x, const FDTensor& y, FDTensor* out) {
}
} // namespace function
FDTensor operator+(const FDTensor& x, const FDTensor& y) {
FDTensor out;
function::Add(x, y, &out);
return out;
}
FDTensor operator-(const FDTensor& x, const FDTensor& y) {
FDTensor out;
function::Subtract(x, y, &out);
return out;
}
FDTensor operator*(const FDTensor& x, const FDTensor& y) {
FDTensor out;
function::Multiply(x, y, &out);
return out;
}
FDTensor operator/(const FDTensor& x, const FDTensor& y) {
FDTensor out;
function::Divide(x, y, &out);
return out;
}
} // namespace fastdeploy

View File

@@ -26,8 +26,6 @@ namespace function {
*/
FASTDEPLOY_DECL void Add(const FDTensor& x, const FDTensor& y, FDTensor* out);
FASTDEPLOY_DECL FDTensor operator+(const FDTensor& x, const FDTensor& y);
/** Excute the subtract operation for input FDTensors. *out = x - y.
@param x The input tensor.
@param y The input tensor.
@@ -36,8 +34,6 @@ FASTDEPLOY_DECL FDTensor operator+(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL void Subtract(const FDTensor& x, const FDTensor& y,
FDTensor* out);
FASTDEPLOY_DECL FDTensor operator-(const FDTensor& x, const FDTensor& y);
/** Excute the multiply operation for input FDTensors. *out = x * y.
@param x The input tensor.
@param y The input tensor.
@@ -46,7 +42,6 @@ FASTDEPLOY_DECL FDTensor operator-(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL void Multiply(const FDTensor& x, const FDTensor& y,
FDTensor* out);
FASTDEPLOY_DECL FDTensor operator*(const FDTensor& x, const FDTensor& y);
/** Excute the divide operation for input FDTensors. *out = x / y.
@param x The input tensor.
@param y The input tensor.
@@ -54,7 +49,6 @@ FASTDEPLOY_DECL FDTensor operator*(const FDTensor& x, const FDTensor& y);
*/
FASTDEPLOY_DECL void Divide(const FDTensor& x, const FDTensor& y,
FDTensor* out);
FASTDEPLOY_DECL FDTensor operator/(const FDTensor& x, const FDTensor& y);
/** Excute the maximum operation for input FDTensors. *out = max(x, y).
@param x The input tensor.
@@ -65,4 +59,13 @@ FASTDEPLOY_DECL void Maximum(const FDTensor& x, const FDTensor& y,
FDTensor* out);
} // namespace function
FASTDEPLOY_DECL FDTensor operator+(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL FDTensor operator-(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL FDTensor operator*(const FDTensor& x, const FDTensor& y);
FASTDEPLOY_DECL FDTensor operator/(const FDTensor& x, const FDTensor& y);
} // namespace fastdeploy

View File

@@ -0,0 +1,42 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/full.h"
#include "fastdeploy/function/eigen.h"
#include <algorithm>
namespace fastdeploy {
namespace function {
template <typename T> void FullValue(FDTensor* tensor, const Scalar& val) {
auto t = EigenVector<T>::Flatten(*tensor);
auto& place = *EigenDeviceWrapper::GetInstance()->GetDevice();
t.device(place) = t.constant(val.to<T>());
}
void Full(const Scalar& value, const std::vector<int64_t>& shape, FDTensor* out,
FDDataType dtype) {
FD_VISIT_ALL_TYPES(dtype, "Full", ([&] {
out->Allocate(shape, dtype);
FullValue<data_t>(out, value);
}));
}
void FullLike(const FDTensor& x, const Scalar& value, FDTensor* out,
FDDataType dtype) {
Full(value, x.Shape(), out, dtype);
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_scalar.h"
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Fill the value to tensor
@param value The value to be filled in tensor
@param shape The shape of output tensor.
@param out The output tensor which stores the result.
@param dtype The data type of output tensor. Default to float32
*/
FASTDEPLOY_DECL void Full(const Scalar& value,
const std::vector<int64_t>& shape, FDTensor* out,
FDDataType dtype = FDDataType::FP32);
/** Fill the value to tensor
@param x The input tensor.
@param value The value to be filled in tensor
@param out The output tensor which stores the result.
@param dtype The data type of output tensor. Default to float32
*/
FASTDEPLOY_DECL void FullLike(const FDTensor& x, const Scalar& value,
FDTensor* out,
FDDataType dtype = FDDataType::FP32);
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/function/cast.h"
#include "fastdeploy/function/clip.h"
#include "fastdeploy/function/concat.h"
#include "fastdeploy/function/cuda_cast.h"
#include "fastdeploy/function/cumprod.h"
#include "fastdeploy/function/elementwise.h"
#include "fastdeploy/function/full.h"
#include "fastdeploy/function/gather_scatter_along_axis.h"
#include "fastdeploy/function/isfinite.h"
#include "fastdeploy/function/linspace.h"
#include "fastdeploy/function/math.h"
#include "fastdeploy/function/pad.h"
#include "fastdeploy/function/quantile.h"
#include "fastdeploy/function/reduce.h"
#include "fastdeploy/function/softmax.h"
#include "fastdeploy/function/sort.h"
#include "fastdeploy/function/split.h"
#include "fastdeploy/function/tile.h"
#include "fastdeploy/function/transpose.h"

View File

@@ -0,0 +1,125 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/gather_scatter_along_axis.h"
#include "fastdeploy/function/tile.h"
namespace fastdeploy {
namespace function {
class TensorAssign {
public:
template <typename tensor_t>
void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data = *src_data;
}
};
static TensorAssign tensor_assign;
template <typename T, typename index_t = int64_t, bool is_scatter_like = true>
struct GatherScatterFunctor {
template <typename func_t>
void operator()(const FDTensor& x, int axis, const FDTensor& index,
FDTensor* result, const func_t& reduce_op) {
if (index.Numel() == 0) {
return;
}
result->Allocate(index.Shape(), x.Dtype());
const T* x_data = reinterpret_cast<const T*>(x.Data());
const index_t* index_data = reinterpret_cast<const index_t*>(index.Data());
T* result_data = reinterpret_cast<T*>(result->Data());
int64_t x_size = x.Numel();
int64_t index_size = index.Numel();
int64_t result_size = result->Numel();
auto x_dims = x.Shape();
auto index_dims = index.Shape();
auto result_dims = result->Shape();
if (x_size == 0 || result_size == 0 || index_size == 0) {
FDASSERT(false, "zero size input found, self_size, result_size, "
"index_size cannot be 0");
return;
}
int select_dim_size = index_dims[axis];
// index matrix has different shape with self matrix or src matrix.
int replaced_select_dim_size =
is_scatter_like ? result_dims[axis] : x_dims[axis];
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
for (int64_t i = 0; i < axis; ++i) {
inner_dim_size *= index_dims[i];
}
for (int i = axis + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}
int64_t index_idx = 0;
int64_t self_idx, src_idx;
// N layer loop squeezed into 3 layers loop
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < select_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = index_data[index_idx];
// This index might out of bound of index matrix's index, so here
// multiply the replaced_select_dim_size.
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * replaced_select_dim_size;
self_idx = is_scatter_like ? replace_index : index_idx;
src_idx = is_scatter_like ? index_idx : replace_index;
reduce_op((T*)(result_data + self_idx), // NOLINT
(T*)(x_data + src_idx)); // NOLINT
index_idx++;
}
}
}
}
};
template <typename T> struct GatherFunctor {
void operator()(const FDTensor& x, int axis, const FDTensor& index,
FDTensor* result) {
FD_VISIT_INT_TYPES(index.Dtype(), "GatherFunctor", [&]() {
auto x_shape = x.Shape();
auto index_shape = index.Shape();
std::vector<int64_t> repeat_times(x_shape.size(), 1);
for (int i = 0; i < x_shape.size(); ++i) {
repeat_times[i] = x_shape[i] / index_shape[i];
}
repeat_times[axis] = 1;
FDTensor gs_index;
Tile(index, repeat_times, &gs_index);
GatherScatterFunctor<T, data_t, /*is_scatter_like=*/false>()(
x, axis, gs_index, result, tensor_assign);
});
}
};
void GatherAlongAxis(const FDTensor& x, const FDTensor& index, FDTensor* result,
int axis) {
int rank = x.Shape().size();
FDASSERT(axis >= -rank && axis < rank,
"axis should be in range [-%d, %d - 1].", rank, rank - 1);
if (axis < 0) {
axis += rank;
}
FD_VISIT_ALL_TYPES(x.Dtype(), "GatherAlongAxis", [&]() {
GatherFunctor<data_t>()(x, axis, index, result);
});
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Output is obtained by gathering entries of axis of x indexed by index and
* concatenate them together.
@param x The input tensor.
@param index The index of a tensor to gather.
@param out The output tensor which stores the result.
@param axis Axis which will be gathered.
*/
void GatherAlongAxis(const FDTensor& x, const FDTensor& index, FDTensor* result,
int axis);
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,111 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/isfinite.h"
#include "fastdeploy/core/float16.h"
#include <algorithm>
#include <type_traits>
namespace fastdeploy {
namespace function {
template <typename T, typename OutT, class Enable = void> struct IsNanFunctor {
OutT operator()(const T& a) const { return static_cast<OutT>(std::isnan(a)); }
};
template <typename T, typename OutT>
struct IsNanFunctor<T, OutT,
typename std::enable_if<std::is_integral<T>::value>::type> {
OutT operator()(const T& a) const { return static_cast<OutT>(false); }
};
template <typename OutT> struct IsNanFunctor<fastdeploy::float16, OutT, void> {
OutT operator()(const fastdeploy::float16& a) const {
return static_cast<OutT>(fastdeploy::isnan(a));
}
};
template <typename T, typename OutT, class Enable = void> struct IsInfFunctor {
OutT operator()(const T& a) const { return static_cast<OutT>(std::isinf(a)); }
};
template <typename T, typename OutT>
struct IsInfFunctor<T, OutT,
typename std::enable_if<std::is_integral<T>::value>::type> {
OutT operator()(const T& a) const { return static_cast<OutT>(false); }
};
template <typename OutT> struct IsInfFunctor<fastdeploy::float16, OutT, void> {
OutT operator()(const fastdeploy::float16& a) const {
return static_cast<OutT>(fastdeploy::isinf(a));
}
};
template <typename T, typename OutT, class Enable = void>
struct IsFiniteFunctor {
OutT operator()(const T& a) const {
return static_cast<OutT>(std::isfinite(a));
}
};
template <typename T, typename OutT>
struct IsFiniteFunctor<
T, OutT, typename std::enable_if<std::is_integral<T>::value>::type> {
OutT operator()(const T& a) const { return static_cast<OutT>(true); }
};
template <typename OutT>
struct IsFiniteFunctor<fastdeploy::float16, OutT, void> {
OutT operator()(const fastdeploy::float16& a) const {
return static_cast<OutT>(fastdeploy::isfinite(a));
}
};
#define DEFINE_ISFINITE_KERNEL(isfinite_kernel, functor) \
template <typename T> \
void isfinite_kernel(const FDTensor& x, FDTensor* out, FDDataType dtype) { \
FD_VISIT_ALL_TYPES(dtype, #isfinite_kernel, ([&] { \
out->Allocate(x.Shape(), dtype); \
functor<T, data_t> unary_func; \
data_t* out_ptr = \
reinterpret_cast<data_t*>(out->Data()); \
const T* input_ptr = \
reinterpret_cast<const T*>(x.Data()); \
std::transform(input_ptr, input_ptr + x.Numel(), \
out_ptr, unary_func); \
})); \
}
DEFINE_ISFINITE_KERNEL(IsNanKernel, IsNanFunctor)
DEFINE_ISFINITE_KERNEL(IsInfKernel, IsInfFunctor)
DEFINE_ISFINITE_KERNEL(IsFiniteKernel, IsFiniteFunctor)
#undef DEFINE_ISFINITE_KERNEL
void IsNan(const FDTensor& x, FDTensor* out, FDDataType dtype) {
FD_VISIT_FLOAT_TYPES(x.dtype, "IsNanKernel",
([&] { IsNanKernel<data_t>(x, out, dtype); }));
}
void IsInf(const FDTensor& x, FDTensor* out, FDDataType dtype) {
FD_VISIT_FLOAT_TYPES(x.dtype, "IsInfKernel",
([&] { IsInfKernel<data_t>(x, out, dtype); }));
}
void IsFinite(const FDTensor& x, FDTensor* out, FDDataType dtype) {
FD_VISIT_FLOAT_TYPES(x.dtype, "IsFiniteKernel",
([&] { IsFiniteKernel<data_t>(x, out, dtype); }));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Return whether every element of input tensor is NaN or not.
@param x The input tensor.
@param out The output tensor which stores the result.
@param dtype The output data type
*/
FASTDEPLOY_DECL void IsNan(const FDTensor& x, FDTensor* out,
FDDataType dtype = FDDataType::BOOL);
/** Return whether every element of input tensor is Inf or not.
@param x The input tensor.
@param out The output tensor which stores the result.
@param dtype The output data type
*/
FASTDEPLOY_DECL void IsInf(const FDTensor& x, FDTensor* out,
FDDataType dtype = FDDataType::BOOL);
/** Return whether every element of input tensor is finite or not.
@param x The input tensor.
@param out The output tensor which stores the result.
@param dtype The output data type
*/
FASTDEPLOY_DECL void IsFinite(const FDTensor& x, FDTensor* out,
FDDataType dtype = FDDataType::BOOL);
} // namespace function
} // namespace fastdeploy

View File

@@ -40,6 +40,8 @@ DEFINE_ACTIVATION_KERNEL(Log, LogFunctor)
DEFINE_ACTIVATION_KERNEL(Round, RoundFunctor)
DEFINE_ACTIVATION_KERNEL(Exp, ExpFunctor)
DEFINE_ACTIVATION_KERNEL(Abs, AbsFunctor)
DEFINE_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_ACTIVATION_KERNEL(Floor, FloorFunctor)
void Sqrt(const FDTensor& x, FDTensor* out) {
FD_VISIT_FLOAT_TYPES(x.dtype, "SqrtKernel",
@@ -66,5 +68,15 @@ void Abs(const FDTensor& x, FDTensor* out) {
([&] { AbsKernel<data_t>(x, out); }));
}
void Ceil(const FDTensor& x, FDTensor* out) {
FD_VISIT_FLOAT_TYPES(x.dtype, "CeilKernel",
([&] { CeilKernel<data_t>(x, out); }));
}
void Floor(const FDTensor& x, FDTensor* out) {
FD_VISIT_FLOAT_TYPES(x.dtype, "FloorKernel",
([&] { FloorKernel<data_t>(x, out); }));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -49,5 +49,17 @@ FASTDEPLOY_DECL void Exp(const FDTensor& x, FDTensor* out);
*/
FASTDEPLOY_DECL void Abs(const FDTensor& x, FDTensor* out);
/** Computes ceil of x element-wise. Only for float type FDTensor
@param x The input tensor.
@param out The output tensor which stores the result.
*/
FASTDEPLOY_DECL void Ceil(const FDTensor& x, FDTensor* out);
/** Computes floor of x element-wise. Only for float type FDTensor
@param x The input tensor.
@param out The output tensor which stores the result.
*/
FASTDEPLOY_DECL void Floor(const FDTensor& x, FDTensor* out);
} // namespace function
} // namespace fastdeploy

View File

@@ -61,5 +61,21 @@ template <typename T> struct AbsFunctor {
}
};
// ceil(x) = ceiling(x)
template <typename T> struct CeilFunctor {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.ceil();
}
};
// floor(x) = flooring(x)
template <typename T> struct FloorFunctor {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.floor();
}
};
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,130 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/quantile.h"
#include "fastdeploy/core/fd_scalar.h"
#include "fastdeploy/function/cast.h"
#include "fastdeploy/function/concat.h"
#include "fastdeploy/function/elementwise.h"
#include "fastdeploy/function/gather_scatter_along_axis.h"
#include "fastdeploy/function/isfinite.h"
#include "fastdeploy/function/math.h"
#include "fastdeploy/function/reduce.h"
#include "fastdeploy/function/sort.h"
#include "fastdeploy/function/transpose.h"
#include <algorithm>
#include <cmath>
#include <numeric>
namespace fastdeploy {
namespace function {
template <typename T>
void QuantileKernel(const FDTensor& x, const std::vector<double>& q,
const std::vector<int>& axis, FDTensor* out) {
FDASSERT(q.size() > 0, "q should not be empty.");
FDASSERT(axis.size() > 0, "axis should not be empty.");
std::vector<int64_t> axis_src;
std::vector<int64_t> out_shape = x.Shape();
int64_t rank = x.Shape().size();
for (auto axis_single : axis) {
FDASSERT(axis_single >= -rank && axis_single < rank,
"The axis is expected to be in range of [%d, %d), but got %d",
-rank, rank, axis_single);
if (axis_single < 0) {
axis_single += rank;
}
axis_src.push_back(axis_single);
out_shape[axis_single] = 1;
}
std::vector<int64_t> axis_dst;
for (int64_t i = 0; i < rank; ++i) {
if (std::find(axis_src.begin(), axis_src.end(), i) == axis_src.end()) {
axis_dst.push_back(i);
}
}
axis_dst.insert(axis_dst.end(), axis_src.begin(), axis_src.end());
FDTensor y;
Transpose(x, &y, axis_dst);
std::vector<int64_t> y_shape(rank - axis_src.size(), 0);
y_shape.push_back(-1);
y.Reshape({y_shape});
int64_t target_axis = rank - 1;
FDTensor mask, valid_counts, mask_any;
IsNan(y, &mask);
Any(mask, &mask_any, {target_axis}, true);
bool* mask_data = reinterpret_cast<bool*>(mask.Data());
std::transform(mask_data, mask_data + mask.Numel(), mask_data,
[](const bool& val) { return !val; });
Cast(mask_any, &mask_any, FDDataType::FP64);
Cast(mask, &mask, FDDataType::FP64);
Sum(mask, &valid_counts, {target_axis}, true);
FDTensor one_tensor(Scalar(static_cast<double>(1.0)));
std::vector<FDTensor> indices;
FDTensor last_index(Scalar(static_cast<double>(x.Shape()[target_axis])));
for (auto q_num : q) {
FDASSERT(q_num >= 0 && q_num <= 1, "q should be in range [0, 1]");
FDTensor q_tensor(static_cast<double>(q_num));
FDTensor index = q_tensor * (valid_counts - one_tensor);
index = mask_any * last_index + (one_tensor - mask_any) * index;
indices.push_back(index);
}
std::vector<FDTensor> outputs;
FDTensor sorted_tensor, sorted_indices_tensor;
Sort(y, &sorted_tensor, &sorted_indices_tensor, target_axis);
Cast(sorted_tensor, &sorted_tensor, FDDataType::FP64);
FDTensor indices_below, indices_upper;
for (auto&& index : indices) {
Floor(index, &indices_below);
Ceil(index, &indices_upper);
Cast(indices_below, &indices_below, FDDataType::INT32);
Cast(indices_upper, &indices_upper, FDDataType::INT32);
FDTensor tensor_below, tensor_upper;
GatherAlongAxis(sorted_tensor, indices_below, &tensor_below, target_axis);
GatherAlongAxis(sorted_tensor, indices_upper, &tensor_upper, target_axis);
// Need to cast to FP64 to compute with index and tensor_upper
Cast(indices_below, &indices_below, FDDataType::FP64);
FDTensor weight = index - indices_below;
FDTensor out = tensor_below + weight * (tensor_upper - tensor_below);
out.Squeeze(target_axis);
if (out.Dtype() != x.Dtype()) {
Cast(out, &out, x.Dtype());
}
outputs.push_back(std::move(out));
}
if (outputs.size() > 1) {
// Execute stack operation
for (auto& output : outputs) {
output.ExpandDim(0);
}
Concat(outputs, out, 0);
} else {
*out = std::move(outputs[0]);
}
}
void Quantile(const FDTensor& x, const std::vector<double>& q,
const std::vector<int>& axis, FDTensor* out) {
FD_VISIT_FLOAT_TYPES(x.dtype, "QuantileKernel",
([&] { QuantileKernel<data_t>(x, q, axis, out); }));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,34 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/** Compute the quantile of the input along the specified axis. If any values
** in a reduced row are NaN, then the quantiles for that reduction will be NaN.
@param x The input tensor.
@param q The q for calculate quantile, which should be in range [0, 1].
@param axis The axis along which to calculate quantile. axis should be int
or list of int.
@param out The output tensor which stores the result.
*/
FASTDEPLOY_DECL void Quantile(const FDTensor& x, const std::vector<double>& q,
const std::vector<int>& axis, FDTensor* out);
} // namespace function
} // namespace fastdeploy

118
fastdeploy/function/sort.cc Normal file
View File

@@ -0,0 +1,118 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/sort.h"
#include "fastdeploy/function/eigen.h"
#include "fastdeploy/function/transpose.h"
#include <algorithm>
#include <cmath>
#include <numeric>
namespace fastdeploy {
namespace function {
template <typename T, typename Type>
static void FullSort(Type input_height, Type input_width, int input_dim,
const FDTensor* input, FDTensor* out, FDTensor* indices,
bool descending) {
out->Allocate(input->Shape(), input->Dtype());
indices->Allocate(input->Shape(), TypeToDataType<Type>::dtype);
T* t_out = reinterpret_cast<T*>(out->Data());
Type* t_indices = reinterpret_cast<Type*>(indices->Data());
for (Type i = 0; i < input_height; ++i) {
std::vector<std::pair<T, Type>> col_vec;
col_vec.reserve(input_width);
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
for (Type j = 0; j < input_width; ++j) {
col_vec.push_back(std::pair<T, Type>(e_input(j), j));
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
col_vec.push_back(std::pair<T, Type>(e_input(i, j), j));
}
}
std::sort(col_vec.begin(), col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + j] = col_vec[j].first;
t_indices[i * input_width + j] = col_vec[j].second;
}
}
}
template <typename T>
void SortKernel(const FDTensor& x, FDTensor* out, FDTensor* indices,
FDDataType indices_type, bool descending, int axis) {
auto input_shape = x.Shape();
int rank = input_shape.size();
axis = (axis < 0) ? (rank + axis) : axis;
// Do full sort
if (axis == -1 || axis + 1 == rank) {
const int64_t input_width = input_shape[rank - 1];
const int64_t input_height = x.Numel() / input_width;
FD_VISIT_INT_TYPES(indices_type, "FullSort", ([&] {
FullSort<T, data_t>(input_height, input_width, rank,
&x, out, indices, descending);
}));
} else {
// If not full sort do transpose
std::vector<int64_t> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(rank - 1);
for (int i = axis + 1; i < rank - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
FDTensor trans_inp;
Transpose(x, &trans_inp, trans);
const int64_t input_width = input_shape[axis];
const int64_t input_height = x.Numel() / input_width;
FD_VISIT_INT_TYPES(indices_type, "FullSort", ([&] {
FullSort<T, data_t>(input_height, input_width, rank,
&trans_inp, out, indices,
descending);
}));
// transpose back
Transpose(*out, out, trans);
Transpose(*indices, indices, trans);
}
}
void Sort(const FDTensor& x, FDTensor* out, FDTensor* indices, int axis,
bool descending, FDDataType indices_type) {
FD_VISIT_INT_FLOAT_TYPES(x.dtype, "SortKernel", ([&] {
SortKernel<data_t>(x, out, indices, indices_type,
descending, axis);
}));
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
namespace function {
/**
* @brief Performs sorting on the input tensor along the given axis and outputs
* two tensors, Output(Out) and Output(Indices). They reserve the same
* shape with Input(X), and Output(Out) represents the sorted tensor
* while Output(Indices) gives the sorted order along the given axis
* Attr(axis).
* @param x The input of sort
* @param out The sorted tensor of sort op, with the same shape as
* x
* @param indices The indices of a tensor giving the sorted order, with
* the same shape as x
* @param axis The axis along which to sort the tensor.
* When axis < 0, the actual axis will be the |axis|'th
* counting backwards
* @param descending The descending attribute is a flag to tell
* algorithm how to sort the input data.
* If descending is true, will sort by descending order,
* else if false, sort by ascending order
* @param indices_type The data type of indices, default to int64
*/
FASTDEPLOY_DECL void Sort(const FDTensor& x, FDTensor* out, FDTensor* indices,
int axis = 0, bool descending = false,
FDDataType indices_type = FDDataType::INT64);
} // namespace function
} // namespace fastdeploy

View File

@@ -37,13 +37,13 @@ void SortBoxes(std::vector<std::array<int, 8>>* boxes) {
}
for (int i = 0; i < boxes->size() - 1; i++) {
if (abs((*boxes)[i + 1][1] - (*boxes)[i][1]) < 10 &&
if (std::abs((*boxes)[i + 1][1] - (*boxes)[i][1]) < 10 &&
((*boxes)[i + 1][0] < (*boxes)[i][0])) {
std::swap((*boxes)[i], (*boxes)[i + 1]);
}
}
}
} // namesoace ocr
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/cast.h"
#include "glog/logging.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <vector>
namespace fastdeploy {
namespace function {
std::vector<float> CreateTestData() {
// Shape: [2, 3, 4]
std::vector<float> x_data = {
1.8428625, 0.6461913, 0.13740455, 0.11430702, 0.659926, 0.535816,
0.7429162, 0.8456049, -1.21228176, 0.29970083, 0.8621713, 0.40894133,
0.12684688, 2.1566195, -9.42884097, 20.8476526, 0.2458633, 0.669046,
0.87888306, 0.6762589, 0.666453, 0.32523027, 0.4139388, 0.8341406};
return x_data;
}
TEST(fastdeploy, cast) {
CheckShape check_shape;
CheckData check_data;
FDTensor x, y;
auto test_data = CreateTestData();
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
Cast(x, &y, FDDataType::INT32);
std::vector<int> result = {1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0,
0, 2, -9, 20, 0, 0, 0, 0, 0, 0, 0, 0};
check_shape(y.shape, {2, 3, 4});
check_data(reinterpret_cast<const int*>(y.Data()), result.data(),
result.size());
}
} // namespace function
} // namespace fastdeploy

View File

@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <array>
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/concat.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <vector>
namespace fastdeploy {
namespace function {

View File

@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/float16.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <cstring>
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/float16.h"
#include "gtest/gtest.h"
#include "gtest_utils.h"
namespace fastdeploy {
@@ -113,7 +113,7 @@ TEST(float16, comparison_cpu) {
TEST(float16, floating) {
// compile time assert.
FDASSERT(std::is_floating_point<float16>::value,
"The float16 support in CPU failed.")
"The float16 support in CPU failed.");
}
TEST(float16, print) {

View File

@@ -0,0 +1,50 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/full.h"
#include "glog/logging.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <vector>
namespace fastdeploy {
namespace function {
TEST(fastdeploy, full) {
CheckShape check_shape;
CheckData check_data;
FDTensor y;
Full(1, {2, 3, 4}, &y);
std::vector<float> result(24, 1);
check_shape(y.Shape(), {2, 3, 4});
check_data(reinterpret_cast<float*>(y.Data()),
reinterpret_cast<float*>(result.data()), result.size());
}
TEST(fastdeploy, full_like) {
CheckShape check_shape;
CheckData check_data;
FDTensor x, y;
x.Allocate({3, 4}, FDDataType::FP32);
FullLike(x, 0, &y, FDDataType::INT32);
std::vector<int> result(12, 0);
check_shape(y.Shape(), {3, 4});
check_data(reinterpret_cast<int*>(y.Data()),
reinterpret_cast<int*>(result.data()), result.size());
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/gather_scatter_along_axis.h"
#include "glog/logging.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <vector>
namespace fastdeploy {
namespace function {
std::vector<float> CreateTestData() {
// Shape: [2, 3, 4]
std::vector<float> x_data = {
1.8428625, 0.6461913, 0.13740455, 0.11430702, 0.659926, 0.535816,
0.7429162, 0.8456049, -1.21228176, 0.29970083, 0.8621713, 0.40894133,
0.12684688, 2.1566195, -9.42884097, 20.8476526, 0.2458633, 0.669046,
0.87888306, 0.6762589, 0.666453, 0.32523027, 0.4139388, 0.8341406};
return x_data;
}
TEST(fastdeploy, gather) {
CheckShape check_shape;
CheckData check_data;
FDTensor x, y;
auto test_data = CreateTestData();
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
FDTensor index;
index.Resize({1, 1, 1}, FDDataType::INT32);
reinterpret_cast<int*>(index.Data())[0] = 0;
GatherAlongAxis(x, index, &y, 0);
std::vector<float> result = {1.842862, 0.646191, 0.137405, 0.114307,
0.659926, 0.535816, 0.742916, 0.845605,
-1.212282, 0.299701, 0.862171, 0.408941};
check_shape(y.shape, {1, 3, 4});
check_data(reinterpret_cast<const float*>(y.Data()), result.data(),
result.size());
reinterpret_cast<int*>(index.Data())[0] = 1;
GatherAlongAxis(x, index, &y, 1);
result = {0.659926, 0.535816, 0.742916, 0.845605,
0.245863, 0.669046, 0.878883, 0.676259};
check_shape(y.shape, {2, 1, 4});
check_data(reinterpret_cast<const float*>(y.Data()), result.data(),
result.size());
index.Resize({1, 1, 2});
reinterpret_cast<int*>(index.Data())[0] = 0;
reinterpret_cast<int*>(index.Data())[1] = 2;
GatherAlongAxis(x, index, &y, 2);
result = {1.842862, 0.137405, 0.659926, 0.742916, -1.212282, 0.862171,
0.126847, -9.428841, 0.245863, 0.878883, 0.666453, 0.413939};
check_shape(y.shape, {2, 3, 2});
check_data(reinterpret_cast<const float*>(y.Data()), result.data(),
result.size());
}
} // namespace function
} // namespace fastdeploy

View File

@@ -0,0 +1,79 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/isfinite.h"
#include "glog/logging.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <cmath>
#include <vector>
namespace fastdeploy {
namespace function {
std::vector<float> CreateTestData() {
// Shape: [2, 3]
std::vector<float> x_data = {0.8428625, NAN, INFINITY,
0.11430702, 0.659926, 0.535816};
return x_data;
}
TEST(fastdeploy, finite) {
CheckShape check_shape;
CheckData check_data;
FDTensor x, y;
auto test_data = CreateTestData();
x.SetExternalData({2, 3}, FDDataType::FP32, test_data.data());
std::array<bool, 6> result = {false, true, false, false, false, false};
IsNan(x, &y);
check_shape(y.shape, {2, 3});
check_data(reinterpret_cast<const bool*>(y.Data()), result.data(),
result.size());
std::vector<int> int_result = {0, 1, 0, 0, 0, 0};
IsNan(x, &y, FDDataType::INT32);
check_shape(y.shape, {2, 3});
check_data(reinterpret_cast<const int*>(y.Data()), int_result.data(),
int_result.size());
result = {false, false, true, false, false, false};
IsInf(x, &y);
check_shape(y.shape, {2, 3});
check_data(reinterpret_cast<const bool*>(y.Data()), result.data(),
result.size());
int_result = {0, 0, 1, 0, 0, 0};
IsInf(x, &y, FDDataType::INT32);
check_shape(y.shape, {2, 3});
check_data(reinterpret_cast<const int*>(y.Data()), int_result.data(),
int_result.size());
result = {true, false, false, true, true, true};
IsFinite(x, &y);
check_shape(y.shape, {2, 3});
check_data(reinterpret_cast<const bool*>(y.Data()), result.data(),
result.size());
int_result = {1, 0, 0, 1, 1, 1};
IsFinite(x, &y, FDDataType::INT32);
check_shape(y.shape, {2, 3});
check_data(reinterpret_cast<const int*>(y.Data()), int_result.data(),
int_result.size());
}
} // namespace function
} // namespace fastdeploy

View File

@@ -71,6 +71,22 @@ TEST(fastdeploy, exp_sqrt_round_log) {
check_data(reinterpret_cast<const float*>(y.Data()), round_result.data(),
round_result.size());
Ceil(x, &y);
std::vector<float> ceil_result = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
check_shape(y.shape, {2, 3, 4});
check_data(reinterpret_cast<const float*>(y.Data()), ceil_result.data(),
ceil_result.size());
Floor(x, &y);
std::vector<float> floor_result = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
check_shape(y.shape, {2, 3, 4});
check_data(reinterpret_cast<const float*>(y.Data()), floor_result.data(),
floor_result.size());
// Test Log function
Log(x, &y);
std::vector<float> log_result = {

View File

@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <numeric>
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/pad.h"
#include <numeric>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
namespace fastdeploy {
namespace function {
@@ -29,12 +29,10 @@ TEST(fastdeploy, pad_2d) {
CheckData check_data;
CheckType check_type;
std::vector<float> inputs = {2, 4, 3,
7, 1, 5};
std::vector<float> expected_result = {2.2, 2.2, 2.2, 2.2, 2.2,
2.2, 2, 4, 3, 2.2,
2.2, 7, 1, 5, 2.2,
2.2, 2.2, 2.2, 2.2, 2.2};
std::vector<float> inputs = {2, 4, 3, 7, 1, 5};
std::vector<float> expected_result = {2.2, 2.2, 2.2, 2.2, 2.2, 2.2, 2,
4, 3, 2.2, 2.2, 7, 1, 5,
2.2, 2.2, 2.2, 2.2, 2.2, 2.2};
input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
Pad(input, &output, {1, 1, 1, 1}, 2.2);
@@ -50,12 +48,9 @@ TEST(fastdeploy, pad_2d_int32_t) {
CheckData check_data;
CheckType check_type;
std::vector<int32_t> inputs = {2, 4, 3,
7, 1, 5};
std::vector<int32_t> expected_result = {2, 2, 2, 2, 2,
2, 2, 4, 3, 2,
2, 7, 1, 5, 2,
2, 2, 2, 2, 2};
std::vector<int32_t> inputs = {2, 4, 3, 7, 1, 5};
std::vector<int32_t> expected_result = {2, 2, 2, 2, 2, 2, 2, 4, 3, 2,
2, 7, 1, 5, 2, 2, 2, 2, 2, 2};
input.SetExternalData({2, 3}, FDDataType::INT32, inputs.data());
Pad(input, &output, {1, 1, 1, 1}, 2.2);

View File

@@ -0,0 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/quantile.h"
#include "glog/logging.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <vector>
namespace fastdeploy {
namespace function {
std::vector<float> CreateTestData() {
// Shape: [2, 3, 4]
std::vector<float> x_data = {
1.8428625, 0.6461913, 0.13740455, 0.11430702, 0.659926, 0.535816,
0.7429162, 0.8456049, -1.21228176, 0.29970083, 0.8621713, 0.40894133,
0.12684688, 2.1566195, -9.42884097, 20.8476526, 0.2458633, 0.669046,
0.87888306, 0.6762589, 0.666453, 0.32523027, 0.4139388, 0.8341406};
return x_data;
}
TEST(fastdeploy, quantile) {
CheckShape check_shape;
CheckData check_data;
FDTensor x, y;
auto test_data = CreateTestData();
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
std::vector<float> result = {1.834282, 2.149067, 0.089573, 20.743986,
0.657856, 0.66838, 0.878203, 0.844758,
0.657059, 0.325103, 0.85993, 0.832015};
Quantile(x, {0.995}, {0}, &y);
check_shape(y.shape, {3, 4});
check_data(reinterpret_cast<const float*>(y.Data()), result.data(),
result.size());
result = {1.831033, 0.645088, 0.860979, 0.841238,
0.662247, 2.141744, 0.874234, 20.647517};
Quantile(x, {0.995}, {1}, &y);
check_shape(y.shape, {2, 4});
check_data(reinterpret_cast<const float*>(y.Data()), result.data(),
result.size());
result = {1.824912, 0.844065, 0.855373, 20.567287, 0.875844, 0.831625};
Quantile(x, {0.995}, {2}, &y);
check_shape(y.shape, {2, 3});
check_data(reinterpret_cast<const float*>(y.Data()), result.data(),
result.size());
}
} // namespace function
} // namespace fastdeploy

View File

@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <array>
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/reduce.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <vector>
namespace fastdeploy {
namespace function {

View File

@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/softmax.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <vector>
namespace fastdeploy {
namespace function {

109
tests/function/test_sort.cc Normal file
View File

@@ -0,0 +1,109 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/sort.h"
#include "glog/logging.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
#include <array>
#include <vector>
namespace fastdeploy {
namespace function {
std::vector<float> CreateTestData() {
// Shape: [2, 3, 4]
std::vector<float> x_data = {
0.8428625, 0.6461913, 0.13740455, 0.11430702, 0.659926, 0.535816,
0.7429162, 0.8456049, 0.21228176, 0.29970083, 0.8621713, 0.40894133,
0.12684688, 0.1566195, 0.42884097, 0.8476526, 0.2458633, 0.669046,
0.87888306, 0.6762589, 0.666453, 0.32523027, 0.4139388, 0.8341406};
return x_data;
}
TEST(fastdeploy, sort_dim0) {
CheckShape check_shape;
CheckData check_data;
FDTensor x, out, indices;
auto test_data = CreateTestData();
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
Sort(x, &out, &indices, 0);
std::vector<float> out_result = {
0.126847, 0.15662, 0.137405, 0.114307, 0.245863, 0.535816,
0.742916, 0.676259, 0.212282, 0.299701, 0.413939, 0.408941,
0.842862, 0.646191, 0.428841, 0.847653, 0.659926, 0.669046,
0.878883, 0.845605, 0.666453, 0.32523, 0.862171, 0.834141};
std::vector<int64_t> indices_result = {1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1};
check_shape(out.shape, {2, 3, 4});
check_data(reinterpret_cast<const float*>(out.Data()), out_result.data(),
out_result.size());
check_shape(indices.shape, {2, 3, 4});
check_data(reinterpret_cast<const int64_t*>(indices.Data()),
indices_result.data(), indices_result.size());
}
TEST(fastdeploy, sort_dim1) {
CheckShape check_shape;
CheckData check_data;
FDTensor x, out, indices;
auto test_data = CreateTestData();
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
Sort(x, &out, &indices, 1);
std::vector<float> out_result = {
0.212282, 0.299701, 0.137405, 0.114307, 0.659926, 0.535816,
0.742916, 0.408941, 0.842862, 0.646191, 0.862171, 0.845605,
0.126847, 0.15662, 0.413939, 0.676259, 0.245863, 0.32523,
0.428841, 0.834141, 0.666453, 0.669046, 0.878883, 0.847653};
std::vector<int64_t> indices_result = {2, 2, 0, 0, 1, 1, 1, 2, 0, 0, 2, 1,
0, 0, 2, 1, 1, 2, 0, 2, 2, 1, 1, 0};
check_shape(out.shape, {2, 3, 4});
check_data(reinterpret_cast<const float*>(out.Data()), out_result.data(),
out_result.size());
check_shape(indices.shape, {2, 3, 4});
check_data(reinterpret_cast<const int64_t*>(indices.Data()),
indices_result.data(), indices_result.size());
}
TEST(fastdeploy, sort_dim2) {
CheckShape check_shape;
CheckData check_data;
FDTensor x, out, indices;
auto test_data = CreateTestData();
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
Sort(x, &out, &indices, 2);
std::vector<float> out_result = {
0.114307, 0.137405, 0.646191, 0.842862, 0.535816, 0.659926,
0.742916, 0.845605, 0.212282, 0.299701, 0.408941, 0.862171,
0.126847, 0.15662, 0.428841, 0.847653, 0.245863, 0.669046,
0.676259, 0.878883, 0.32523, 0.413939, 0.666453, 0.834141};
std::vector<int64_t> indices_result = {3, 2, 1, 0, 1, 0, 2, 3, 0, 1, 3, 2,
0, 1, 2, 3, 0, 1, 3, 2, 1, 2, 0, 3};
check_shape(out.shape, {2, 3, 4});
check_data(reinterpret_cast<const float*>(out.Data()), out_result.data(),
out_result.size());
check_shape(indices.shape, {2, 3, 4});
check_data(reinterpret_cast<const int64_t*>(indices.Data()),
indices_result.data(), indices_result.size());
}
} // namespace function
} // namespace fastdeploy

View File

@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <numeric>
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/transpose.h"
#include <numeric>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "gtest/gtest.h"
namespace fastdeploy {
namespace function {

View File

@@ -13,9 +13,9 @@
// limitations under the License.
#pragma once
#include <vector>
#include <cmath>
#include "gtest/gtest.h"
#include <cmath>
#include <vector>
namespace fastdeploy {
@@ -33,8 +33,8 @@ struct CheckData {
template <typename T>
void operator()(const T* lhs_ptr, const T* rhs_ptr, int num, int atol = 0) {
for (int i = 0; i < num; ++i) {
// ASSERT_FLOAT_EQ(lhs_ptr[i], rhs_ptr[i]);
int abs_diff = abs(lhs_ptr[i] - rhs_ptr[i]);
// ASSERT_FLOAT_EQ(lhs_ptr[i], rhs_ptr[i]);
int abs_diff = std::abs(lhs_ptr[i] - rhs_ptr[i]);
if (abs_diff > atol) {
std::cout << "lhs_ptr: " << static_cast<int64_t>(lhs_ptr[i])
<< " rhs_ptr: " << static_cast<int64_t>(rhs_ptr[i])
@@ -44,16 +44,16 @@ struct CheckData {
ASSERT_EQ(1, 1);
}
}
void operator()(const float* lhs_ptr, const float* rhs_ptr,
int num, float atol = 1e-06, float rtol = 1e-06) {
void operator()(const float* lhs_ptr, const float* rhs_ptr, int num,
float atol = 1e-06, float rtol = 1e-06) {
for (int i = 0; i < num; ++i) {
float abs_diff = fabs(lhs_ptr[i] - rhs_ptr[i]);
float rel_diff = abs_diff / (std::max(fabs(lhs_ptr[i]),
fabs(rhs_ptr[i])) + 1e-06);
float rel_diff =
abs_diff / (std::max(fabs(lhs_ptr[i]), fabs(rhs_ptr[i])) + 1e-06);
if (abs_diff > atol && rel_diff > rtol) {
std::cout << "lhs_ptr: " << lhs_ptr[i] << " rhs_ptr: "
<< rhs_ptr[i] << " abs_diff: " << abs_diff
<< " rel_diff: " << rel_diff << std::endl;
std::cout << "lhs_ptr: " << lhs_ptr[i] << " rhs_ptr: " << rhs_ptr[i]
<< " abs_diff: " << abs_diff << " rel_diff: " << rel_diff
<< std::endl;
ASSERT_EQ(1, 0);
}
ASSERT_EQ(1, 1);