Files
FastDeploy/paddle2onnx/mapper/tensor/temporal_shift.cc
Jason 6343b0db47 [Build] Support build with source code of Paddle2ONNX (#1559)
* Add notes for tensors

* Optimize some apis

* move some warnings

* Support build with Paddle2ONNX

* Add protobuf support

* Fix compile on mac

* add clean package script

* Add paddle2onnx code

* remove submodule

* Add onnx code

* remove softlink

* add onnx code

* fix error

* Add cmake file

* fix patchelf

* update paddle2onnx

* Delete .gitmodules

---------

Co-authored-by: PaddleCI <paddle_ci@example.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: jiangjiajun <jiangjiajun@baidu.lcom>
2023-03-17 10:03:22 +08:00

80 lines
2.7 KiB
C++

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle2onnx/mapper/tensor/temporal_shift.h"
namespace paddle2onnx {
// Registers TemporalShiftMapper for the Paddle `temporal_shift` operator.
REGISTER_MAPPER(temporal_shift, TemporalShiftMapper)

// Decides whether temporal_shift can be exported and returns the minimum ONNX
// opset required.
//
// Opset7() bakes the C, H and W sizes of the input into constant Reshape
// shapes and Slice bounds. The input "X" must therefore be a 4-D NCHW tensor
// whose C, H and W dimensions are known at conversion time.
//
// @param verbose Not consulted; failures are always reported through Error().
// @return 7 when the op can be exported, -1 otherwise.
int32_t TemporalShiftMapper::GetMinOpset(bool verbose) {
  if (data_format_ == "NHWC") {
    Error() << "Only support data_format of NCHW, but now the data format is "
            << data_format_ << "." << std::endl;
    return -1;
  }
  // Validate the op's input tensor, the same tensor Opset7() reads.
  auto input_info = GetInput("X");
  if (input_info[0].Rank() != 4) {
    Error() << "The input dims must be 4, but now the input dims is "
            << std::to_string(input_info[0].Rank()) << "." << std::endl;
    return -1;
  }
  // C, H and W are emitted as constants. A dynamic (negative) size would
  // produce a Reshape shape with more than one -1 entry, which is invalid
  // ONNX. It would also make the channel split points meaningless.
  for (size_t i = 1; i < 4; ++i) {
    if (input_info[0].shape[i] < 0) {
      Error() << "The C, H and W dimensions of the input must be static, but "
                 "dimension "
              << i << " of the input is " << input_info[0].shape[i] << "."
              << std::endl;
      return -1;
    }
  }
  return 7;
}

// Exports temporal_shift using only Reshape/Pad/Slice/Concat.
//
// With X of shape [N * T, C, H, W] (T = seg_num_):
//   1. Reshape X to [N, T, C, H, W].
//   2. Zero-pad axis 1 (T) by one frame at each end -> [N, T + 2, C, H, W].
//   3. Slice three channel groups out of the padded tensor:
//        channels [0, C1)  from padded frames [0, T)     -> source frame t - 1
//        channels [C1, C2) from padded frames [2, T + 2) -> source frame t + 1
//        channels [C2, C)  from padded frames [1, T + 1) -> source frame t
//      where C1 = C * shift_ratio_ and C2 = 2 * C * shift_ratio_, both
//      truncated by the conversion to int64_t.
//   4. Concat the groups on axis 2 and reshape back to [N * T, C, H, W].
void TemporalShiftMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");
  const int64_t C = input_info[0].shape[1];
  const int64_t H = input_info[0].shape[2];
  const int64_t W = input_info[0].shape[3];

  std::vector<int64_t> reshape_shape = {-1, seg_num_, C, H, W};
  std::string reshape_input =
      helper_->Reshape(input_info[0].name, reshape_shape);

  // ONNX pad layout: [x1_begin, ..., x5_begin, x1_end, ..., x5_end]. Only
  // axis 1 (T) is padded: index 1 is its begin, index 6 is its end.
  std::vector<int64_t> paddings(10, 0);
  paddings[1] = 1;
  paddings[6] = 1;

  std::string pad_node;
  if (helper_->GetOpsetVersion() < 11) {
    // Before opset 11, Pad takes `pads` and `value` as attributes.
    auto node = helper_->MakeNode("Pad", {reshape_input});
    AddAttribute(node, "pads", paddings);
    float val = 0.0;
    AddAttribute(node, "value", val);
    pad_node = node->output(0);
  } else {
    // From opset 11, `pads` is an int64 input tensor. It is created only here
    // so the older-opset path does not leave an unused Constant in the graph.
    std::string padding_constant_node =
        helper_->Constant(GetOnnxDtype(P2ODataType::INT64), paddings);
    auto node =
        helper_->MakeNode("Pad", {reshape_input, padding_constant_node});
    pad_node = node->output(0);
  }

  const int64_t C1 = static_cast<int64_t>(C * shift_ratio_);
  const int64_t C2 = static_cast<int64_t>(2 * C * shift_ratio_);

  // Slices are taken on axes {1 (T), 2 (C)}. Empty channel groups (e.g.
  // C1 == 0 when C * shift_ratio_ < 1) are skipped so no zero-sized tensors
  // reach Concat; the remaining groups keep their channel order.
  std::vector<std::string> groups;
  if (C1 > 0) {
    groups.push_back(
        helper_->Slice(pad_node, {1, 2}, {0, 0}, {seg_num_, C1}));
  }
  if (C2 > C1) {
    groups.push_back(
        helper_->Slice(pad_node, {1, 2}, {2, C1}, {2 + seg_num_, C2}));
  }
  if (C > C2) {
    groups.push_back(
        helper_->Slice(pad_node, {1, 2}, {1, C2}, {1 + seg_num_, C}));
  }
  std::string concat_out = helper_->Concat(groups, 2);
  helper_->Reshape(concat_out, output_info[0].name, {-1, C, H, W});
}
}  // namespace paddle2onnx