diff --git a/fastdeploy/function/concat.cc b/fastdeploy/function/concat.cc
new file mode 100644
index 000000000..c2b1f2744
--- /dev/null
+++ b/fastdeploy/function/concat.cc
@@ -0,0 +1,124 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/function/concat.h"
+
+#include <cstring>
+#include <limits>
+#include <set>
+#include <sstream>
+#include "fastdeploy/utils/utils.h"
+
+namespace fastdeploy {
+
+std::string Str(const std::vector<int64_t>& shape) {
+  std::ostringstream oss;
+  oss << "[ " << shape[0];
+  for (int i = 1; i < shape.size(); ++i) {
+    oss << " ," << shape[i];
+  }
+  oss << " ]";
+  return oss.str();
+}
+
+std::vector<int64_t> ComputeAndCheckConcatOutputShape(
+    const std::vector<FDTensor>& input, int axis) {
+  const size_t n = input.size();
+  auto out_dims = input[0].shape;
+  size_t in_zero_dims_size = out_dims.size();
+  for (size_t i = 1; i < n; ++i) {
+    FDASSERT(input[i].shape.size() == out_dims.size(),
+             "The shape of input[0] and input[%d] is expected to be equal. But "
+             "received input[0]'s shape = %s, input[%d]'s shape = %s.",
+             i, Str(out_dims).c_str(), i, Str(input[i].shape).c_str());
+    for (size_t j = 0; j < in_zero_dims_size; j++) {
+      if (j == axis) {
+        out_dims[axis] += input[i].shape[axis];
+      } else {
+        FDASSERT(
+            input[0].shape[j] == input[i].shape[j],
+            "The %d-th dimension of input[0] and input[%d] is expected to be "
+            "equal. "
+            "But received input[0]'s shape = %s, input[%d]'s shape = %s.",
+            j, i, Str(input[0].shape).c_str(), i, Str(input[i].shape).c_str());
+      }
+    }
+  }
+  return out_dims;
+}
+
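+// ConcatFunctor views each input as a 2-D block: "rows" is the product of the
+// dimensions before `axis`, and each input contributes Numel() / rows columns
+// (all elements from `axis` onwards). Concatenation then reduces to copying
+// every input's row slice into its column offset within the output rows.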
+ "But received input[0]'s shape = %s, input[%d]'s shape = %s.", + j, i, Str(input[0].shape).c_str(), i, Str(input[i].shape).c_str()); + } + } + } + return out_dims; +} + +template +struct ConcatFunctor { + void operator()(const std::vector& input, int axis, + FDTensor* output) { + size_t num = input.size(); + + int64_t rows = 1; + auto dim_0 = input[0].shape; + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int64_t out_rows = rows, out_cols = 0; + + std::vector input_cols(num); + for (size_t i = 0; i < num; ++i) { + int64_t t_cols = input[i].Numel() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + + // computation + T* output_data = reinterpret_cast(output->Data()); + int64_t col_idx = 0; + for (size_t j = 0; j < num; ++j) { + int64_t col_len = input_cols[j]; + const T* input_data = reinterpret_cast(input[j].Data()); + for (int64_t k = 0; k < out_rows; ++k) { + std::memcpy(output_data + k * out_cols + col_idx, + input_data + k * col_len, sizeof(T) * col_len); + } + col_idx += col_len; + } + } +}; + +template +void ConcatKernel(const std::vector& input, FDTensor* output, + int axis) { + auto output_shape = ComputeAndCheckConcatOutputShape(input, axis); + output->Allocate(output_shape, TypeToDataType::dtype); + + ConcatFunctor functor; + functor(input, axis, output); +} + +void Concat(const std::vector& x, FDTensor* out, int axis) { + FDASSERT(x.size() > 0, + "The number of FDTensor array should be larger than 0, but the size " + "of input is %d", + x.size()); + int64_t rank = x[0].shape.size(); + FDASSERT(axis >= -rank && axis < rank, + "The axis is expected to be in range of [%d, %d), but got %d", -rank, + rank, axis); + if (axis < 0) { + axis += rank; + } + FDTensor out_temp; + FD_VISIT_ALL_TYPES(x[0].dtype, "Concat", + ([&] { ConcatKernel(x, &out_temp, axis); })); + *out = std::move(out_temp); +} + +} // namespace fastdeploy diff --git a/fastdeploy/function/concat.h b/fastdeploy/function/concat.h new file mode 100644 index 000000000..22e388b0f --- /dev/null +++ b/fastdeploy/function/concat.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/core/fd_tensor.h" + +namespace fastdeploy { + +/** Excute the concatenate operation for input FDTensor along given axis. + @param x The input tensor. + @param out The output tensor which stores the result. + @param axisi Axis which will be concatenated. +*/ + +FASTDEPLOY_DECL void Concat(const std::vector& x, FDTensor* out, + int axis = 0); + +} // namespace fastdeploy diff --git a/tests/function/test_concat.cc b/tests/function/test_concat.cc new file mode 100644 index 000000000..4bdd49dd1 --- /dev/null +++ b/tests/function/test_concat.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+void Concat(const std::vector<FDTensor>& x, FDTensor* out, int axis) {
+  FDASSERT(x.size() > 0,
+           "The number of FDTensor array should be larger than 0, but the size "
+           "of input is %d",
+           x.size());
+  int64_t rank = x[0].shape.size();
+  FDASSERT(axis >= -rank && axis < rank,
+           "The axis is expected to be in range of [%d, %d), but got %d", -rank,
+           rank, axis);
+  if (axis < 0) {
+    axis += rank;
+  }
+  FDTensor out_temp;
+  FD_VISIT_ALL_TYPES(x[0].dtype, "Concat",
+                     ([&] { ConcatKernel<data_t>(x, &out_temp, axis); }));
+  *out = std::move(out_temp);
+}
+
+}  // namespace fastdeploy
diff --git a/fastdeploy/function/concat.h b/fastdeploy/function/concat.h
new file mode 100644
index 000000000..22e388b0f
--- /dev/null
+++ b/fastdeploy/function/concat.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fastdeploy/core/fd_tensor.h"
+
+namespace fastdeploy {
+
+/** Execute the concatenate operation for the input FDTensors along the given axis.
+    @param x The input tensors.
+    @param out The output tensor which stores the result.
+    @param axis The axis along which the inputs will be concatenated.
+*/
+
+FASTDEPLOY_DECL void Concat(const std::vector<FDTensor>& x, FDTensor* out,
+                            int axis = 0);
+
+}  // namespace fastdeploy
diff --git a/tests/function/test_concat.cc b/tests/function/test_concat.cc
new file mode 100644
index 000000000..4bdd49dd1
--- /dev/null
+++ b/tests/function/test_concat.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <array>
+#include <vector>
+#include "fastdeploy/core/fd_tensor.h"
+#include "fastdeploy/function/concat.h"
+#include "glog/logging.h"
+#include "gtest/gtest.h"
+#include "gtest_utils.h"
+
+namespace fastdeploy {
+
+TEST(fastdeploy, concat1) {
+  CheckShape check_shape;
+  std::vector<FDTensor> inputs(3);
+  FDTensor output;
+  inputs[0].Allocate({5, 1, 4, 5}, FDDataType::FP32);
+  inputs[1].Allocate({5, 2, 4, 5}, FDDataType::FP32);
+  inputs[2].Allocate({5, 3, 4, 5}, FDDataType::FP32);
+  Concat(inputs, &output, 1);
+  check_shape(output.shape, {5, 6, 4, 5});
+}
+
+TEST(fastdeploy, concat2) {
+  CheckShape check_shape;
+  std::vector<FDTensor> inputs(3);
+  FDTensor output;
+  inputs[0].Allocate({2, 3, 4, 5}, FDDataType::FP32);
+  inputs[1].Allocate({2, 3, 4, 5}, FDDataType::FP32);
+  inputs[2].Allocate({2, 3, 4, 5}, FDDataType::FP32);
+  Concat(inputs, &output, 1);
+  check_shape(output.shape, {2, 9, 4, 5});
+}
+
+TEST(fastdeploy, concat3) {
+  CheckShape check_shape;
+  std::vector<FDTensor> inputs(3);
+  FDTensor output;
+  inputs[0].Allocate({1, 256, 170, 256}, FDDataType::FP32);
+  inputs[1].Allocate({1, 128, 170, 256}, FDDataType::FP32);
+  inputs[2].Allocate({1, 128, 170, 256}, FDDataType::FP32);
+  Concat(inputs, &output, 1);
+  check_shape(output.shape, {1, 512, 170, 256});
+}
+
+TEST(fastdeploy, concat4) {
+  CheckShape check_shape;
+  std::vector<FDTensor> inputs(3);
+  FDTensor output;
+  inputs[0].Allocate({2, 3, 4, 5}, FDDataType::FP32);
+  inputs[1].Allocate({2, 3, 4, 5}, FDDataType::FP32);
+  inputs[2].Allocate({0, 3, 4, 5}, FDDataType::FP32);
+  Concat(inputs, &output, 0);
+  check_shape(output.shape, {4, 3, 4, 5});
+}
+
+TEST(fastdeploy, concat5) {
+  CheckShape check_shape;
+  std::vector<FDTensor> inputs(3);
+  FDTensor output;
+  inputs[0].Allocate({5, 1, 4, 5}, FDDataType::FP32);
+  inputs[1].Allocate({5, 2, 4, 5}, FDDataType::FP32);
+  inputs[2].Allocate({5, 3, 4, 5}, FDDataType::FP32);
+  Concat(inputs, &output, -3);
+  check_shape(output.shape, {5, 6, 4, 5});
+}
+
+}  // namespace fastdeploy
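
For reference, a minimal standalone usage sketch of the API added by this diff. It mirrors the tests above; the value-filling helper is only illustrative and assumes the FP32 buffers exposed by FDTensor::Data(), as used in concat.cc.

#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/concat.h"

int main() {
  // Two FP32 tensors that agree on every dimension except axis 1.
  std::vector<fastdeploy::FDTensor> inputs(2);
  inputs[0].Allocate({2, 3}, fastdeploy::FDDataType::FP32);
  inputs[1].Allocate({2, 5}, fastdeploy::FDDataType::FP32);

  // Illustrative fill through the raw buffers.
  auto fill = [](fastdeploy::FDTensor& t, float value) {
    float* data = reinterpret_cast<float*>(t.Data());
    for (int64_t i = 0; i < t.Numel(); ++i) data[i] = value;
  };
  fill(inputs[0], 1.0f);
  fill(inputs[1], 2.0f);

  // Concatenate along axis 1; the result has shape {2, 8}.
  fastdeploy::FDTensor out;
  fastdeploy::Concat(inputs, &out, 1);
  return 0;
}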