// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// NOTE(review): every #include below has no header name in this copy of the
// file, and several template parameter/argument lists further down are empty.
// It looks like the text between angle brackets was stripped at some point;
// restore the original names from the upstream file before building.
#include
#include
#include
#include
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include
#endif
#include
#include  // For cute::elect_one_sync()
#include
#include
#include
#include

using namespace cute;

// Trait that maps an element type to its packed two-element vector type
// (exposed as the nested alias `Type`).
// NOTE(review): the template parameter list is missing here — confirm against
// upstream.
template
struct PackedHalf;

// Specialization whose packed type is __half2.
// NOTE(review): the specialization argument is missing; presumably the
// half-precision element type — confirm.
template<>
struct PackedHalf {
    using Type = __half2;
};

// Specialization whose packed type is nv_bfloat162.
// NOTE(review): the specialization argument is missing; presumably the
// bfloat16 element type — confirm.
template<>
struct PackedHalf {
    using Type = nv_bfloat162;
};

// Builds a GmmaDescriptor for the memory pointed to by `smem_ptr`.
//
// Parameters:
//   smem_ptr            - pointer that is converted with
//                         __cvta_generic_to_shared(); presumably it must point
//                         into shared memory — confirm with callers.
//   layout_type         - value stored unchanged in the descriptor's
//                         layout_type_ field.
//   leading_byte_offset - byte offset, stored divided by 16 (>> 4).
//   stride_byte_offset  - byte offset, stored divided by 16 (>> 4).
//
// Returns the populated descriptor; base_offset_ is always 0.
//
// NOTE(review): the template parameter list (declaring PointerType) and the
// static_cast target type are missing in this copy — confirm against upstream.
template
__device__ GmmaDescriptor make_smem_desc(
        PointerType smem_ptr,
        int layout_type,
        int leading_byte_offset = 0,
        int stride_byte_offset = 1024) {
    GmmaDescriptor desc;
    auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr));
    // Address and offsets are stored with their low 4 bits dropped, i.e. in
    // units of 16 bytes.
    desc.bitfield.start_address_ = uint_ptr >> 4;
    desc.bitfield.layout_type_ = layout_type;
    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.bitfield.base_offset_ = 0;
    return desc;
}

// Helper that issues a single Mma::fma call.
//
// The accumulator array `d` is expanded element by element (d[Idx]...) so that
// each accumulator is passed to Mma::fma as its own argument, followed by the
// word `e` and GMMA::ScaleOut::One.
//
// Parameters:
//   desc_a, desc_b - 64-bit descriptors for the two operands.
//   d              - accumulator array; indexed by every value in the Idx pack.
//   e              - 32-bit word forwarded to Mma::fma; presumably sparsity
//                    metadata — confirm against the Mma type in use.
//   (unnamed)      - index sequence tag used only to deduce the Idx pack.
//
// NOTE(review): the template parameter list (declaring Mma and Idx) and the
// std::index_sequence arguments are missing in this copy — confirm upstream.
template
__forceinline__ __device__ static void gemm(uint64_t const& desc_a, uint64_t const& desc_b, float* d, const uint32_t e, std::index_sequence) {
    Mma::fma(desc_a, desc_b, d[Idx]..., e, GMMA::ScaleOut::One);
}

// Runs the MMA loop over one K-block and waits for it to finish.
//
// For each of the kBlockK / 64 iterations it builds descriptors for sA and sB,
// issues one Mma::fma through the helper above, then commits the batch and
// waits on it.
//
// Parameters:
//   sA    - base pointer of operand A; advanced by 32 elements per iteration.
//   sB    - base pointer of operand B; advanced by 64 elements per iteration.
//   acc_c - float accumulators, sizeof(Mma::CRegisters) / sizeof(float) of them.
//   E     - array of 32-bit words; iteration i reads E[i * NumMmaThreads].
//           The index is not offset by a thread id here, so the caller
//           presumably passes a pointer already offset per thread — confirm.
//
// NOTE(review): A advances half as far as B per iteration, which together with
// E suggests a 2:4-compressed sparse A operand — confirm.
// NOTE(review): the template parameter list (presumably declaring Mma,
// kBlockK, NumMmaThreads and T) and the template arguments on the inner
// make_smem_desc / gemm / std::make_index_sequence uses are missing in this
// copy — confirm against upstream.
template
__forceinline__ __device__ void gemm(
        const T * sA,
        const T * sB,
        float * acc_c,
        const uint32_t *E) {
    // Number of float accumulators held by one Mma instance.
    // NOTE(review): if Mma is a template parameter, Mma::CRegisters is a
    // dependent name and standard C++ expects `typename` when it names a
    // type — confirm this is accepted by the target toolchain.
    constexpr int acc_num = sizeof(Mma::CRegisters) / sizeof(float);

    warpgroup_arrive();
    // Selected index pair -> corresponding encoded value
    // (each value equals first_index | (second_index << 2); the values are
    // listed in decimal; presumably this is the encoding used inside E —
    // confirm)
    // 01  4
    // 02  8
    // 03  12
    // 12  9
    // 13  13
    // 23  14
    #pragma unroll
    for (int i = 0; i < kBlockK / 64; i++) {
        // Both descriptors use layout_type 1, leading offset 0 and stride
        // offset 1024 bytes.
        // NOTE(review): layout_type 1 is presumably the 128-byte swizzle mode
        // of the GMMA descriptor — confirm.
        GmmaDescriptor a_desc = make_smem_desc(sA + i * 32, 1, 0, 1024);
        GmmaDescriptor b_desc = make_smem_desc(sB + i * 64, 1, 0, 1024);
        gemm(a_desc, b_desc, acc_c, E[i * NumMmaThreads], std::make_index_sequence{});
    }

    // Commit everything issued above as one batch and wait for it before
    // returning.
    warpgroup_commit_batch();
    warpgroup_wait<0>();
}