[Function] Add slice function (#719)

* fix math functions

* add slice function
This commit is contained in:
Jack Zhou
2022-11-28 15:33:33 +08:00
committed by GitHub
parent dd18471b41
commit d0307192f9
5 changed files with 282 additions and 3 deletions

View File

@@ -28,11 +28,13 @@ namespace function {
template <typename T, typename Functor>
void ActivationImpl(const FDTensor& X, FDTensor* Out, const Functor& functor) {
FDASSERT(Out != nullptr, "Output Out should not be nullptr");
FDTensor out_tmp;
auto x = EigenVector<T>::Flatten(X);
Out->Allocate(X.Shape(), X.Dtype());
auto out = EigenVector<T>::Flatten(*Out);
out_tmp.Allocate(X.Shape(), X.Dtype());
auto out = EigenVector<T>::Flatten(out_tmp);
const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
functor(dev, x, out);
*Out = std::move(out_tmp);
}
DEFINE_ACTIVATION_KERNEL(Sqrt, SqrtFunctor)