/* * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /** * @file epilogue_helpers.h * * This file includes types for the epilogues. The empty structs exist so we can signal to template * code the type of epilogue we want to run, and let the underlying code specify the details such as * element types, accumulator type and elements per vector access. * */ #pragma once #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/thread/linear_combination_generic.h" #include "cutlass/epilogue/thread/linear_combination_relu.h" #include "cutlass/epilogue/thread/linear_combination_silu.h" #include "cutlass_extensions/epilogue/thread/fused_activations.h" // #include "cutlass/epilogue/fusion/operations.hpp" namespace cutlass_extensions { struct EpilogueOpBiasSilu { }; struct EpilogueOpBiasReLU { }; struct EpilogueOpBiasFtGelu { }; struct EpilogueOpBias { }; struct EpilogueOpDefaultSilu { }; struct EpilogueOpDefaultReLU { }; struct EpilogueOpDefaultFtGelu { }; struct EpilogueOpDefault { }; template struct Epilogue { static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); }; constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationSilu; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationRelu; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationGeneric; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombination; }; constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationSilu; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationRelu; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationGeneric; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombination; }; } // namespace cutlass_extensions