// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp>  // For cute::elect_one_sync()

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

using namespace cute;

// Maps a 16-bit floating-point element type to its packed two-element CUDA vector type.
template <typename T>
struct PackedHalf;

template <>
struct PackedHalf<cutlass::half_t> {
    using Type = __half2;
};

template <>
struct PackedHalf<cutlass::bfloat16_t> {
    using Type = nv_bfloat162;
};

// Element-wise conversion of a register-resident fragment to To_type.
// The fragment must be contiguous in registers for the reinterpret_cast to be valid.
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

// Splits packed 4-bit values into two byte-per-value outputs: dst1 receives the
// high nibble of every byte of src, dst2 receives the low nibble.
template <int numel>
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t *src,
                                                 int32_t *dst1,
                                                 int32_t *dst2) {
    #pragma unroll
    for (int i = 0; i < numel; ++i) {
        dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;
        dst2[i] = src[i] & 0x0f0f0f0f;
    }
}

// Warpgroup GEMM with an A operand packed as 4-bit weights: each k_block of A is
// unpacked into two fp8 fragments and multiplied against two consecutive k-slices
// of B. The non-type parameters control the warpgroup arrive/commit/wait pipeline.
template <int wg_wait = 0, bool arrive = true, bool commit = true,
          typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
          typename TiledMma, typename TiledCopyA, typename ThrCopyA>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma,
                                     Tensor0 &tCrA,
                                     Tensor1 &tCsA,
                                     Tensor2 const &tCrB,
                                     Tensor3 &tCrC,
                                     TiledCopyA const &tiled_copy_A,
                                     ThrCopyA const &thr_copy_A) {
    constexpr bool Is_RS =
        !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // Scratch register tensors holding the unpacked high- and low-nibble fragments of A.
    Tensor tCrA1 = make_tensor<typename Tensor0::value_type>(tCrA.layout());
    Tensor tCrA2 = make_tensor<typename Tensor0::value_type>(tCrA.layout());
    if constexpr (Is_RS) {
        warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
    }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // Prefetch the next k_block of A from shared memory while the current one is computed.
        if (k_block < size<2>(tCrA) - 1) {
            cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
        }
        int32_t *tCrA_data  = reinterpret_cast<int32_t *>(tCrA(_, _, k_block).data());
        int32_t *tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_, _, k_block).data());
        int32_t *tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_, _, k_block).data());
        convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
        cute::gemm(tiled_mma, tCrA1(_, _, k_block), tCrB(_, _, 2 * k_block), tCrC);
        cute::gemm(tiled_mma, tCrA2(_, _, k_block), tCrB(_, _, 2 * k_block + 1), tCrC);
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) {
        warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) {
        warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
    }
}
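
// Worked example (illustrative values only, not used by the code above): for a
// packed A word src = 0xABCDEF12 holding eight 4-bit weights, convert_c4_2_fp8
// produces
//   dst1 = (src >> 4) & 0x0f0f0f0f = 0x0A0C0E01   (high nibbles)
//   dst2 =  src       & 0x0f0f0f0f = 0x0B0D0F02   (low nibbles)
// Each byte of dst1/dst2 is fed to the fp8 tensor-core MMA as-is, which is why
// gemm() expands every k_block of A into two cute::gemm calls against
// tCrB(_, _, 2 * k_block) and tCrB(_, _, 2 * k_block + 1).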