// /*
//  * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
//  * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
//  *
//  * Licensed under the Apache License, Version 2.0 (the "License");
//  * you may not use this file except in compliance with the License.
//  * You may obtain a copy of the License at
//  *
//  * http://www.apache.org/licenses/LICENSE-2.0
//  *
//  * Unless required by applicable law or agreed to in writing, software
//  * distributed under the License is distributed on an "AS IS" BASIS,
//  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  * See the License for the specific language governing permissions and
//  * limitations under the License.
//  */

#pragma once

// NOTE(review): the targets of the first two #include directives were missing
// from this file. <cstdint> and <cuda_runtime.h> provide the standard/CUDA
// names used below (int64_t, uintN_t, dim3, uint4) -- confirm against the
// upstream header before merging.
#include <cstdint>
#include <cuda_runtime.h>

#include "moe/fused_moe_imp_op.h"
#include "moe/fused_moe_helper.h"

#include "cutlass/numeric_conversion.h"
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"

// #include "paddle/phi/backends/gpu/gpu_info.h"
#pragma GCC diagnostic pop

#include "helper.h"

#define WARP_SIZE 32

namespace phi {

// Pair of launch dimensions: grid extent and block extent.
struct GpuLaunchConfig {
  dim3 block_per_grid;
  dim3 thread_per_block;
};

// Builds a grid shape for `cols` work items.
//
// Up to and including 1024 items the grid is one-dimensional (x = cols,
// y = 1). Above that the grid is folded into y = 256 rows and
// x = ceil(cols / 256) columns. z is always 1.
//
// Only `block_per_grid` is filled in; `thread_per_block` keeps its
// default-constructed value and must be set by the caller if needed.
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
  constexpr int64_t kSplitThreshold = 1024;
  constexpr int64_t kSplitRows = 256;

  // Stay in 64-bit arithmetic until after the split: narrowing `cols` to
  // int first would silently truncate counts that do not fit in 32 bits.
  int64_t blocks_x = cols;
  int64_t blocks_y = 1;
  const int64_t blocks_z = 1;
  if (blocks_x > kSplitThreshold) {
    blocks_y = kSplitRows;
    blocks_x = (blocks_x + blocks_y - 1) / blocks_y;  // ceil-div
  }

  GpuLaunchConfig config;
  config.block_per_grid.x = static_cast<unsigned int>(blocks_x);
  config.block_per_grid.y = static_cast<unsigned int>(blocks_y);
  config.block_per_grid.z = static_cast<unsigned int>(blocks_z);
  return config;
}

constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;

// Converts an array-like value of type T into array-like type U, element by
// element, casting each element to U::Element.
//
// Both types must expose a static `kElements` count (checked to be equal at
// compile time) and operator[]; U must also expose an `Element` typedef.
// NOTE(review): this matches the cutlass::Array interface -- confirm that is
// the intended argument type.
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
  using Type = typename U::Element;
  static_assert(T::kElements == U::kElements);
  U u;
#pragma unroll
  for (int i = 0; i < U::kElements; i++) {
    u[i] = static_cast<Type>(input[i]);
  }
  return u;
}

// 32-byte aggregate made of two uint4 halves; used as BytesToType<32>::Type.
struct uint8 {
  uint4 u;
  uint4 v;
};

// Maps a byte count to a type of exactly that size. Only the sizes
// specialized below (32, 16, 8, 4, 2, 1) provide a nested `Type`.
template <int Bytes>
struct BytesToType {};

template <>
struct BytesToType<32> {
  using Type = uint8;
  static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};

// NOTE(review): the declaration opened by the `template` keyword below
// continues on the following line of the file; it is left untouched here.
template