diff --git a/.gitignore b/.gitignore index f17e02f7c..c0beb6c0f 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,6 @@ third_party custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h + +custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu +custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h diff --git a/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/kernel_traits.h b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/kernel_traits.h new file mode 100644 index 000000000..db4e86a2a --- /dev/null +++ b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/kernel_traits.h @@ -0,0 +1,151 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cute/algorithm/copy.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" + +using namespace cute; + +template +struct SharedStorage { + union { + struct { + cute::array_aligned> smem_a; + cute::array_aligned> smem_e; + cute::array_aligned> smem_b; + }; + cute::array_aligned> smem_c; + }; + + struct { + typename cutlass::PipelineTmaAsync::SharedStorage pipeline; + }; +}; + +template +struct Kernel_traits { + using Element = elem_type; + using ElementAccum = float; + using ElementOutput = OutputType; + static_assert(cutlass::sizeof_bits_v == 8); + + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int NumMmaThreads = kNThreads - NumProducerThreads; + + static_assert(kNWarps_ == 12); + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kBlockK = kBlockK_; + static constexpr int kTiles = kTiles_; + static constexpr int TokenPackSize = TokenPackSize_; + static constexpr int TAIL_N = TAIL_N_; + static constexpr int M = M_; + + using TileShape_MNK = Shape, Int, Int>; + using TileShape_MNK_TAIL = Shape, Int, Int>; + static constexpr int kClusterM = kClusterM_; + using ClusterShape_MNK = Shape, _1, _1>; + + static constexpr int kStages = kStages_; + static_assert(kStages > 1); + + using AtomLayoutMNK = Layout, _1, _1>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using Mma = decltype(cute::GMMA::ss_op_selector_sparse()); + + using Mma_TAIL = decltype(cute::GMMA::ss_op_selector_sparse()); + + using SmemLayoutAtomA = decltype( + cutlass::gemm::collective::detail::rs_smem_selector< + GMMA::Major::K, Element, Int, Int>()); + + using SmemLayoutA = decltype( + tile_to_shape(SmemLayoutAtomA{}, + make_shape(Int{}, Int{}, Int{}))); + + using SmemLayoutAtomB = decltype( + cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, Element, 
decltype(cute::get<1>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutB = decltype( + tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomB_TAIL = decltype( + cutlass::gemm::collective::detail::rs_smem_selector< + GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})), + decltype(cute::get<2>(TileShape_MNK_TAIL{}))>()); + + using SmemLayoutB_TAIL = decltype( + tile_to_shape(SmemLayoutAtomB_TAIL{}, + make_shape( + shape<1>(TileShape_MNK_TAIL{}), + shape<2>(TileShape_MNK_TAIL{}), + Int{}) + )); + using SmemLayoutAtomC = decltype( + cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, ElementOutput, + decltype(cute::get<0>(TileShape_MNK{})), + decltype(cute::get<1>(TileShape_MNK{}))>()); + + using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{}))); + + using SmemLayoutE = Layout, Int, Int>>; + + using SharedStorage = SharedStorage< + kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutE, SmemLayoutB, SmemLayoutC>; + + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename cutlass::PipelineState; + + static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v); + static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem; + static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; + using TiledCopyCAtom = cute::Copy_Atom, OutputType>; + using TiledCopyCThrLayout = decltype(cute::make_layout( + cute::make_shape(Int{}, Int{}), + LayoutRight{})); + using TiledCopyCValLayout = decltype(cute::make_layout( + cute::make_shape(_1{}, Int{}), + LayoutRight{})); + using TiledCopyC = decltype(make_tiled_copy( + TiledCopyCAtom{}, + TiledCopyCThrLayout{}, // Thr layout + TiledCopyCValLayout{} // Val layout + )); +}; diff --git a/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/mainloop_fwd.h b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/mainloop_fwd.h new file mode 100644 index 000000000..10f86d53b --- /dev/null +++ b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/mainloop_fwd.h @@ -0,0 +1,466 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
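Before the mainloop, it is worth spelling out the launch geometry that Kernel_traits above pins down: kNWarps = 12 gives 384 threads per CTA, split into one producer warp group (128 threads issuing TMA copies) and two consumer warp groups (NumMmaThreads = 256). The sketch below is my own back-of-envelope arithmetic for the per-stage shared-memory footprint, assuming the largest generated tile (kBlockN = 256), fp8 A/B tiles, a 2:4-compressed weight tile, 4-bit sparsity metadata, and a bf16 output tile that shares the union with the staged buffers; the numbers are illustrative, not values taken from the source.

# Back-of-envelope CTA geometry and shared-memory budget for the traits above
# (illustrative arithmetic only; the kernel's real sizes come from the CuTe layouts).
kBlockM, kBlockN, kBlockK, kStages, kNWarps = 128, 256, 128, 5, 12

threads = kNWarps * 32                     # 384 threads per CTA
producer_threads = 128                     # one warp group drives the TMA loads
mma_threads = threads - producer_threads   # 256 consumer (MMA) threads

a_tile = kBlockM * kBlockK // 2   # 2:4-compressed fp8 weight tile  -> 8 KiB
e_tile = kBlockM * kBlockK // 8   # 4 bits of metadata per original weight element -> 2 KiB
b_tile = kBlockN * kBlockK        # fp8 activation tile              -> 32 KiB
c_tile = kBlockM * kBlockN * 2    # bf16 output tile, reuses the same union -> 64 KiB

smem = max(kStages * (a_tile + e_tile + b_tile), c_tile)
print(threads, mma_threads, smem)  # 384 256 215040 (~210 KiB, under Hopper's ~227 KiB per-block limit)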
+ +#pragma once + +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "utils.hpp" + +using namespace cute; +template +struct CollectiveMainloopFwd { + + using Element = typename Ktraits::Element; + using ElementOutput = typename Ktraits::ElementOutput; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + using ElementAccum = typename Ktraits::ElementAccum; + + static constexpr int kStages = Ktraits::kStages; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr int kBlockK = Ktraits::kBlockK; + static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int kTiles = Ktraits::kTiles; + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); + static constexpr int TokenPackSize = Ktraits::TokenPackSize; + static constexpr int M = Ktraits::M; + + + using GmemTiledCopy = cute::SM90_TMA_LOAD; + using GmemTiledCopyStore = cute::SM90_TMA_STORE; + + using SmemLayoutA = typename Ktraits::SmemLayoutA; + using SmemLayoutB = typename Ktraits::SmemLayoutB; + using SmemLayoutC = typename Ktraits::SmemLayoutC; + using SmemLayoutE = typename Ktraits::SmemLayoutE; + using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; + using LayoutT = cute::Layout; + + using WShapeT = cute::Shape; + using WStrideT = cute::Shape; + using WLayoutT = cute::Layout; + + using EShapeT = cute::Shape; + using EStrideT = cute::Shape<_1, int64_t, int64_t, int64_t, int64_t>; + using ELayoutT = cute::Layout; + + using TMA_A = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + WShapeT{}, + WStrideT{} + ), + SmemLayoutA{}(_, _, _0{}), + select<0, 1>(Shape, Int>{}), + size<0>(ClusterShape{}))); + + using TMA_B = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + ShapeT{}, + StrideT{} + ), + take<0, 2>(SmemLayoutB{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{}))); + + using TMA_E = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + EShapeT{}, + EStrideT{} + ), + SmemLayoutE{}(_, _, _0{}), + select<0, 1>(Shape, Int>{}), + size<0>(ClusterShape{}))); + + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + static constexpr uint32_t TmaTransactionBytesA = static_cast(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesB = static_cast(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesE = static_cast(size(take<0, 2>(SmemLayoutE{})) * cutlass::sizeof_bits_v / 8); + + struct Arguments { + Element const* ptr_A; + WLayoutT layout_A; + uint32_t const* ptr_E; + ELayoutT layout_E; + Element const* ptr_B; + LayoutT layout_B; + ElementOutput * ptr_C; + LayoutT layout_C; + const int *tokens; + const float *weight_scale; + }; + + struct Params { + WLayoutT layout_A; + ELayoutT layout_E; + LayoutT layout_B; + TMA_A tma_load_A; + TMA_E tma_load_E; + TMA_B tma_load_B; + const int *tokens; + const float *weight_scale; + ElementOutput * ptr_C; + }; + + + Params static + 
to_underlying_arguments(Arguments const& args) { + Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A); + TMA_A tma_load_A = make_tma_copy( + GmemTiledCopy{}, + mA, + SmemLayoutA{}(_, _, _0{}), + select<0, 1>(Shape, Int>{}), + size<0>(ClusterShape{})); + Tensor mE = make_tensor(make_gmem_ptr(args.ptr_E), args.layout_E); + TMA_E tma_load_E = make_tma_copy( + GmemTiledCopy{}, + mE, + SmemLayoutE{}(_, _, _0{}), + select<0, 1>(Shape, Int>{}), + size<0>(ClusterShape{})); + Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B); + TMA_B tma_load_B = make_tma_copy( + GmemTiledCopy{}, + mB, + SmemLayoutB{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{})); + + return {args.layout_A, args.layout_E, args.layout_B, + tma_load_A, tma_load_E, tma_load_B, + args.tokens, args.weight_scale, args.ptr_C}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_E.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + store(Params const& mainloop_params, + float * acc_s, + SharedStorage& shared_storage, + const int pre_fix_tokens, + const int tokens, + const float * weight_scale, + const int bidm, + const int bidn, + const int bidb, + const int tidx) { + typename Ktraits::TiledMma tiled_mma; + using packHalf = typename PackedHalf::Type; + Tensor tOrO_out = make_tensor(partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})).layout()); + + #pragma unroll + for (int i = 0; i < size(tOrO_out); i+=4) { + acc_s[i] *= weight_scale[0]; + acc_s[i + 1] *= weight_scale[0]; + acc_s[i + 2] *= weight_scale[1]; + acc_s[i + 3] *= weight_scale[1]; + *reinterpret_cast(&tOrO_out[i]) = packHalf(acc_s[i], acc_s[i + 2]); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(acc_s[i + 1], acc_s[i + 3]); + } + + uint16_t *smem_c = reinterpret_cast(shared_storage.smem_c.data()); + + uint32_t * reg_data = reinterpret_cast(tOrO_out.data()); + + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); + + constexpr int k_copy_times = CUR_N / 16; + + #pragma unroll + for (int i = 0; i < k_copy_times; i++) { + uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast(smem_c + i * 16 * 128) + tidx); + #if defined(CUTE_ARCH_STSM_SM90_ENABLED) + asm volatile ( + "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3])); + #endif + } + + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); + const int batch_idx = TokenPackSize == 0 ? 
pre_fix_tokens * M : bidb * M * TokenPackSize; + ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM; + + const int reamin_tokens = tokens - bidn * kBlockN; + + const int col = tidx % 2; + + constexpr int kPackSize = 16 / sizeof(ElementOutput); + constexpr int kNumVecElem = kBlockM / kPackSize; + constexpr int copy_len = CUR_N * kNumVecElem; + #pragma unroll + for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) { + const int idx_div2 = idx / 2; + const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8; + const int store_global_idx = store_idx * 2 + col; + const int row = store_global_idx / kNumVecElem; + const int col = store_global_idx % kNumVecElem; + if (row >= reamin_tokens) { + continue; + } + const int offset = row * (M / kPackSize) + col; + reinterpret_cast(store_c)[offset] = reinterpret_cast(smem_c)[idx]; + } + } + + template + CUTLASS_DEVICE auto get_local_packed_tensor( + const MTensor &mB, + const int tokens, + const int bidn) const { + + auto mB_this_batch = make_tensor( + mB.data(), + make_layout( + cute::make_shape(tokens, size<1>(mB)), + mB.stride() + )); + return local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); + } + + template + CUTLASS_DEVICE auto get_local_no_packed_tensor( + const MTensor &mB, + const int pre_fix_token, + const int actual_token, + const int bidn) const { + + auto g_offset = local_tile( + mB(_, _, 0), + cute::make_shape(1, size<1>(mB)), + make_coord(pre_fix_token, _0{})); + + auto g_tensor = make_tensor( + g_offset.data(), + make_layout( + cute::make_shape(actual_token, size<1>(mB)), + g_offset.stride() + )); + + Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); + + return gB; + } + + + template + CUTLASS_DEVICE void + load(Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState& smem_pipe_write, + SharedStorage &shared_storage, + const int pre_fix_tokens, + const int tokens, + const int bidm, + const int bidn, + const int bidb, + const int tidx) { + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + Tensor sE = make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{}); + + Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape()); + Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape()); + Tensor mE = mainloop_params.tma_load_E.get_tma_tensor(mainloop_params.layout_E.shape()); + + Tensor gA = local_tile(mA(_, _, _, bidm, bidb), select<0, 1>(Shape, Int>{}), make_coord(0,0,_)); + + Tensor gE = local_tile(mE(_, _, _, bidm, bidb), select<0, 1>(Shape, Int>{}), make_coord(0, 0)); + + auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA)); + + auto [tEgE, tEsE] = tma_partition(mainloop_params.tma_load_E, _0{}, Layout{}, group_modes<0, 2>(sE), group_modes<0, 2>(gE)); + + int lane_predicate = cute::elect_one_sync(); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + + if constexpr (TokenPackSize == 0) { + Tensor gB = get_local_no_packed_tensor( + mB, + pre_fix_tokens, + tokens, + bidn); + auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB)); + + const int kIters = kTiles / kStages; + if (tidx == 0) { + 
#pragma unroll + for (int kiter = 0; kiter < kIters; ++kiter) { + #pragma unroll + for (int s = 0; s < kStages; s++) { + const int i = kiter * kStages + s; + pipeline.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tAgA(_, i), tAsA(_, s)); + copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tEgE(_, i), tEsE(_, s)); + copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tBgB(_, i), tBsB(_, s)); + ++smem_pipe_write; + } + } + + #pragma unroll + for (int i = kIters * kStages; i < kTiles; ++i) { + pipeline.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tAgA(_, i), tAsA(_, smem_pipe_write.index())); + copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tEgE(_, i), tEsE(_, smem_pipe_write.index())); + copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tBgB(_, i), tBsB(_, smem_pipe_write.index())); + ++smem_pipe_write; + } + } + } else { + auto mB_this_batch = make_tensor( + mB(_, _, bidb).data(), + make_layout( + cute::make_shape(tokens, size<1>(mB)), + mB.stride() + )); + Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); + auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB)); + + const int kIters = kTiles / kStages; + if (tidx == 0) { + #pragma unroll + for (int kiter = 0; kiter < kIters; ++kiter) { + #pragma unroll + for (int s = 0; s < kStages; s++) { + const int i = kiter * kStages + s; + pipeline.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tAgA(_, i), tAsA(_, s)); + copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tEgE(_, i), tEsE(_, s)); + copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tBgB(_, i), tBsB(_, s)); + ++smem_pipe_write; + } + } + + #pragma unroll + for (int i = kIters * kStages; i < kTiles; ++i) { + pipeline.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tAgA(_, i), tAsA(_, smem_pipe_write.index())); + copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tEgE(_, i), tEsE(_, smem_pipe_write.index())); + copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + tBgB(_, i), tBsB(_, smem_pipe_write.index())); + ++smem_pipe_write; + } + } + } + } + + template + CUTLASS_DEVICE void + mma(Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState& smem_pipe_read, + SharedStorage& shared_storage, + float *acc_s, + const int tidx) { + + using sMemBLayout = std::conditional_t< + CUR_N == kBlockN, + SmemLayoutB, + SmemLayoutB_TAIL + >; + + using Mma = std::conditional_t< + CUR_N == kBlockN, + typename Ktraits::Mma, + typename Ktraits::Mma_TAIL + >; + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{}); + Tensor sE = make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{}); + + const int wg_idx = tidx / 128; + const int wg_offset = wg_idx * 64; + + auto consumer_wait = [](auto& pipeline, 
auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + constexpr int E_STEP = kBlockK / 64 * NumMmaThreads; + constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N); + + const int kIters = kTiles / kStages; + #pragma unroll + for (int kiter = 0; kiter < kIters; ++kiter) { + #pragma unroll + for (int s = 0; s < kStages; s++) { + consumer_wait(pipeline, smem_pipe_read); + + gemm( + sA(_, _, s).data().get().get() + wg_offset, + sB(_, _, s * B_STEPS).data().get().get(), + acc_s, + shared_storage.smem_e.data() + s * E_STEP + tidx); + + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + } + + #pragma unroll + for (int i = 0; i < kTiles % kStages; ++i) { + consumer_wait(pipeline, smem_pipe_read); + + gemm( + sA(_, _, i).data().get().get() + wg_offset, + sB(_, _, i * B_STEPS).data().get().get(), + acc_s, + shared_storage.smem_e.data() + i * E_STEP + tidx); + + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + } + +}; diff --git a/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/utils.hpp b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/utils.hpp new file mode 100644 index 000000000..33551e919 --- /dev/null +++ b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/utils.hpp @@ -0,0 +1,100 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include // For cute::elect_one_sync() + +#include +#include +#include +#include + + +using namespace cute; + + +template +struct PackedHalf; + +template<> +struct PackedHalf { + using Type = __half2; +}; + +template<> +struct PackedHalf { + using Type = nv_bfloat162; +}; + +template +__device__ GmmaDescriptor make_smem_desc( + PointerType smem_ptr, + int layout_type, + int leading_byte_offset = 0, + int stride_byte_offset = 1024) { + + GmmaDescriptor desc; + auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +__forceinline__ __device__ static void gemm(uint64_t const& desc_a, uint64_t const& desc_b, float* d, const uint32_t e, std::index_sequence) { + Mma::fma(desc_a, desc_b, d[Idx]..., e, GMMA::ScaleOut::One); +} + +template +__forceinline__ __device__ void gemm( + const T * sA, + const T * sB, + float * acc_c, + const uint32_t *E) { + + constexpr int acc_num = sizeof(Mma::CRegisters) / sizeof(float); + + warpgroup_arrive(); + // 选择的下标 对应的16进制 + // 01 4 + // 02 8 + // 03 12 + // 12 9 + // 13 13 + // 23 14 + #pragma unroll + for (int i = 0; i < kBlockK / 64; i++) { + GmmaDescriptor a_desc = make_smem_desc(sA + i * 32, 1, 0, 1024); + GmmaDescriptor b_desc = make_smem_desc(sB + i * 64, 1, 0, 1024); + gemm(a_desc, b_desc, acc_c, E[i * NumMmaThreads], std::make_index_sequence{}); + } + + warpgroup_commit_batch(); + warpgroup_wait<0>(); +} diff --git a/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/w8a8_sparse_gemm_kernel.hpp b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/w8a8_sparse_gemm_kernel.hpp new file mode 100644 index 000000000..c86cba01d --- /dev/null +++ b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/w8a8_sparse_gemm_kernel.hpp @@ -0,0 +1,309 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
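The kernel header that follows launches one CTA per (weight M-tile, token N-tile, batch/expert) triple: warp group 0 drops to 40 registers and only feeds the TMA pipeline, the two MMA warp groups take 232 registers and consume it, and CTAs whose token tile starts beyond the group's token count exit immediately. A small sketch of the grid arithmetic in run_gemm, using the eb45T ffn1 shape listed in the generator script (M = 7168, K = 8192, Batch = 8) and a max_tokens value I made up for illustration:

# Grid geometry mirrored from run_gemm (max_tokens here is an arbitrary example value).
def grid_dims(M, max_tokens, batch, kBlockM=128, kBlockN=256):
    m_nums = M // kBlockM                           # static_assert(M % kBlockM == 0) in the source
    n_nums = (max_tokens + kBlockN - 1) // kBlockN  # token tiles, rounded up
    return (m_nums, n_nums, batch)

print(grid_dims(7168, max_tokens=4096, batch=8))    # -> (56, 16, 8), each block running 12 warps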
+ +#pragma once +#include "cute/atom/mma_atom.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/arch/reg_reconfig.h" + +#include "kernel_traits.h" +#include "mainloop_fwd.h" + +template +void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w8a8_sparse_gemm_kernel( + CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params) { + + using Element = typename Ktraits::Element; + static_assert(cutlass::sizeof_bits_v == 8); + + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); + static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int TokenPackSize = Ktraits::TokenPackSize; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr int TAIL_N = Ktraits::TAIL_N; + static constexpr int M = Ktraits::M; + + using CollectiveMainloop = CollectiveMainloopFwd; + + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesE + CollectiveMainloop::TmaTransactionBytesB; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{}); + + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesB; + + CollectiveMainloop collective_mainloop; + + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + + const int bidm = blockIdx.x; + const int bidn = blockIdx.y; + const int bidb = blockIdx.z; + const int tidx = threadIdx.x; + + const int pre_fix_tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] : 0; + + const int tokens = TokenPackSize == 0 ? 
mainloop_params.tokens[bidb + 1] - pre_fix_tokens : mainloop_params.tokens[bidb]; + + + if (bidn * kBlockN >= tokens) { + return; + } + + if (warp_group_idx == 0) { + cutlass::arch::warpgroup_reg_dealloc<40>(); + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + collective_mainloop.load( + mainloop_params, + pipeline, + smem_pipe_write, + shared_storage, + pre_fix_tokens, + tokens, + bidm, + bidn, + bidb, + tidx); + } else { + cutlass::arch::warpgroup_reg_alloc<232>(); + PipelineState smem_pipe_read; + + constexpr int acc_num = sizeof(typename Ktraits::Mma::CRegisters) / sizeof(float); + float acc_s[acc_num]; + + #pragma unroll + for (int i = 0; i < acc_num; ++i) { + acc_s[i] = 0.0f; + } + + const int reamin_tokens = tokens - bidn * kBlockN; + + const int mma_tidx = tidx - NumCopyThreads; + + const float2 weight_scale = reinterpret_cast(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4]; + + + if (TAIL_N > 0 && reamin_tokens < kBlockN) { + collective_mainloop.mma( + mainloop_params, + pipeline, + smem_pipe_read, + shared_storage, + acc_s, + mma_tidx); + + collective_mainloop.store( + mainloop_params, + acc_s, + shared_storage, + pre_fix_tokens, + tokens, + reinterpret_cast(&weight_scale), + bidm, + bidn, + bidb, + mma_tidx); + } else { + collective_mainloop.mma( + mainloop_params, + pipeline, + smem_pipe_read, + shared_storage, + acc_s, + mma_tidx); + + collective_mainloop.store( + mainloop_params, + acc_s, + shared_storage, + pre_fix_tokens, + tokens, + reinterpret_cast(&weight_scale), + bidm, + bidn, + bidb, + mma_tidx); + } + } + +} + +template +auto get_gmem_layout(int Rows, int Cols) { + return make_layout( + make_shape( + static_cast(Rows), + static_cast(Cols), + static_cast(Batch)), + make_stride( + static_cast(Cols), + cute::_1{}, + static_cast(Rows * Cols))); +} + +template +auto get_weight_gmem_layout(int m_nums, int k_nums, int Rows, int Cols) { + return make_layout( + make_shape( + static_cast(Rows), + static_cast(Cols), + static_cast(k_nums), + static_cast(m_nums), + static_cast(Batch)), + make_stride( + static_cast(Cols), + cute::_1{}, + static_cast(Rows * Cols), + static_cast(Rows * Cols * k_nums), + static_cast(Rows * Cols * k_nums * m_nums))); +} + +template +auto get_gmem_e_layout(int ms, int ks, int ks_in, int Cols) { + return make_layout( + make_shape( + static_cast(Cols), + static_cast(ks_in), + static_cast(ks), + static_cast(ms), + static_cast(Batch)), + make_stride( + cute::_1{}, + static_cast(Cols), + static_cast(ks_in * Cols), + static_cast(ks * ks_in * Cols), + static_cast(ms * ks * Cols * 2))); +} + +template +void run_gemm( + const InputType * A, + const uint32_t *E, + const InputType * B, + OutputType * C, + const float *weight_scale, + const int *tokens_idx, + const int max_tokens, + cudaStream_t stream) { + + using ElementOutput = typename Kernel_traits::ElementOutput; + using Element = typename Kernel_traits::Element; + using CollectiveMainloop = CollectiveMainloopFwd; + using ClusterShape = typename Kernel_traits::ClusterShape_MNK; + constexpr int NumMmaThreads = Kernel_traits::NumMmaThreads; + constexpr int kBlockK = Kernel_traits::kBlockK; + constexpr int kBlockM = Kernel_traits::kBlockM; + + static_assert(M % Kernel_traits::kBlockM == 0); + constexpr int M_nums = M / Kernel_traits::kBlockM; + const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + + constexpr int kTiles = Kernel_traits::kTiles; + + typename CollectiveMainloop::Params mainloop_params = + 
CollectiveMainloop::to_underlying_arguments({ + static_cast(A), + get_weight_gmem_layout(M_nums, kTiles, kBlockM / 2, kBlockK), + static_cast(E), + get_gmem_e_layout(M_nums, kTiles, kBlockK / 64, NumMmaThreads), + static_cast(B), + get_gmem_layout(kPackTokenSize == 0 ? max_tokens * Batch : kPackTokenSize, K), + static_cast(C), + get_gmem_layout(M, kPackTokenSize == 0 ? max_tokens : kPackTokenSize), + tokens_idx, + weight_scale, + }); + + void *kernel; + kernel = (void *)w8a8_sparse_gemm_kernel; + + int smem_size = sizeof(typename Kernel_traits::SharedStorage); + + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + + dim3 grid_dims; + grid_dims.x = M_nums; + grid_dims.y = N_nums; + grid_dims.z = Batch; + static constexpr int ctaSize = Kernel_traits::kNWarps * 32; + dim3 block_dims(ctaSize); + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster( + launch_params, kernel, mainloop_params); +} + +template +void w8a8_sparse_gemm( + const InputType * A, + const uint32_t * E, + const InputType * B, + OutputType * C, + const float *weight_scale, + const int *tokens_idx, + const int max_tokens, + cudaStream_t stream) { + constexpr static int kBlockM = 128; + constexpr static int kBlockK = 128; + constexpr static int kNWarps = 4 + kBlockM / 16; + constexpr static int kStages = 5; + constexpr int kCluster = 1; + static_assert(K % kBlockK == 0); + constexpr int kTiles = K / kBlockK; + const int max_tokens_pack16 = (max_tokens + 31) / 32 * 32; + + using Kernel_traits = Kernel_traits; + run_gemm(A, E, B, C, weight_scale, tokens_idx, max_tokens_pack16, stream); +} diff --git a/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm.cu b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm.cu new file mode 100644 index 000000000..03d3c16a1 --- /dev/null +++ b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm.cu @@ -0,0 +1,112 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
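The Paddle-facing dispatcher below picks which generated template instance to call from the runtime token count: tokens are first rounded up to a multiple of 32, the main block is kBlockN = 256 with TailN covering the remainder, and small batches (max_tokens < 256) fall back to a single rounded-up block with no tail. A sketch of that selection logic, with a couple of example token counts of my own choosing:

# Mirror of the (kBlockN, TailN) selection in DisPatchWFp8AFp8Gemm (example inputs are mine).
def pick_block_n(max_tokens):
    max_tokens_pack32 = (max_tokens + 31) // 32 * 32
    k_block_n, tail_n = 256, max_tokens_pack32 % 256
    if max_tokens < 256:
        k_block_n, tail_n = max_tokens_pack32, 0
    return k_block_n, tail_n

print(pick_block_n(100))   # -> (128, 0)   small batch: one rounded-up block, no tail kernel
print(pick_block_n(1100))  # -> (256, 96)  full 256-wide blocks plus a 96-wide tail instance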
+ +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#include "helper.h" +#include "paddle/extension.h" +#include "wfp8Afp8_sparse_gemm_template.h" + +template +void DisPatchWFp8AFp8Gemm( + const cutlass::float_e4m3_t* input, + const uint32_t* sparse_idx, + const cutlass::float_e4m3_t* weight, + const int * tokens, + const float * weight_scale, + OutputType * out, + const int token_padding_size, + const int max_tokens, + const int batch_size, + const int M, + const int K, + cudaStream_t stream) { + + const int max_tokens_pack32 = (max_tokens + 31) / 32 * 32; + + int kBlockN = 256; + int TailN = max_tokens_pack32 % kBlockN; + if (max_tokens < 256) { + kBlockN = max_tokens_pack32; + TailN = 0; + } + if constexpr (std::is_same_v) { + SPARSE_GEMM_SWITCH_BF16(M, K, batch_size, token_padding_size, kBlockN, TailN, + weight, + sparse_idx, + input, + out, + weight_scale, + tokens, + max_tokens, + stream) + } else { + PD_THROW("Only supported dtype in ['BFLOAT16']."); + } +} + +void WFp8AFp8Gemm( + const paddle::Tensor& input, + const paddle::Tensor& sparse_idx, + const paddle::Tensor& weight, + const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group + const paddle::Tensor& weight_scale, + const paddle::Tensor& out, + const int token_padding_size, + const int max_tokens, + const bool is_bfloat16) { + + const int batch_size = weight.dims()[0]; + const int M = weight.dims()[1]; + const int K = weight.dims()[2] * 2; + + if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) { + PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN']."); + } + + if (is_bfloat16) { + DisPatchWFp8AFp8Gemm( + reinterpret_cast(input.data()), + reinterpret_cast(sparse_idx.data()), + reinterpret_cast(weight.data()), + tokens.data(), + weight_scale.data(), + reinterpret_cast(const_cast(out.data())), + token_padding_size, + max_tokens, + batch_size, + M, + K, + input.stream() + ); + } else { + PD_THROW("Only supported dtype in ['BFLOAT16']."); + } +} + +PD_BUILD_STATIC_OP(wfp8afp8_sparse_gemm) + .Inputs({"input", + "sparse_idx", + "weight", + "tokens", + "weight_scale", + "ffn_out"}) + .Outputs({"out"}) + .SetInplaceMap({{"ffn_out", "out"}}) + .Attrs({"token_padding_size: int", + "max_tokens: int", + "is_bfloat16: bool"}) + .SetKernelFn(PD_KERNEL(WFp8AFp8Gemm)); diff --git a/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm_weight.cu b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm_weight.cu new file mode 100644 index 000000000..9871dea4b --- /dev/null +++ b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm_weight.cu @@ -0,0 +1,96 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
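The converter below prepares the 2:4 sparsity metadata offline. The hex table that appears both in utils.hpp and inside pack_E (its Chinese caption, 选择的下标 对应的16进制, reads "selected indices and the corresponding hexadecimal value") maps the pair of kept positions within each group of four weights to a 4-bit code: the low two bits hold the first kept index and the next two bits the second; pack_E then packs eight such nibbles into one uint32, least-significant nibble first. A small self-contained check of that encoding:

# 2:4 metadata nibble encoding and packing, mirroring select_idx / the inner loop of pack_E.
KEPT_PAIR_TO_NIBBLE = {
    (0, 1): 0x4, (0, 2): 0x8, (0, 3): 0xC,
    (1, 2): 0x9, (1, 3): 0xD, (2, 3): 0xE,
}

def nibble(kept_lo, kept_hi):
    return kept_lo | (kept_hi << 2)        # first kept index in bits 0-1, second in bits 2-3

assert all(nibble(*pair) == code for pair, code in KEPT_PAIR_TO_NIBBLE.items())

def pack8(nibbles):                        # eight groups of four weights -> one uint32
    word = 0
    for n in reversed(nibbles):            # same order as the k2 = 7..0 loop in pack_E
        word = (word << 4) | n
    return word

print(hex(pack8([0x4, 0x8, 0xC, 0x9, 0xD, 0xE, 0x4, 0x8])))  # -> 0x84ed9c84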
+ +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#include "helper.h" +#include "paddle/extension.h" + + +void pack_E(const uint8_t *E_src, int32_t *E_dst, const int M, const int K, const int Batch) { + // 选择的下标 对应的16进制 + // 01 4 + // 02 8 + // 03 12 + // 12 9 + // 13 13 + // 23 14 + const int ld1 = K / 4; + const int ld2 = K / 4 / 8; + const uint8_t select_idx[6] = {14, 13, 9, 12, 8, 4}; + for (int b = 0; b < Batch; ++b) { + for (int m = 0; m < M; ++m) { + for (int k = 0; k < ld1; k+=8) { + uint32_t dst = 0; + for (int k2 = 7; k2 > 0; --k2) { + dst |= select_idx[E_src[b * M * ld1 + m * ld1 + k + k2]]; + dst <<= 4; + } + dst |= select_idx[E_src[b * M * ld1 + m * ld1 + k]]; + E_dst[b * M * ld2 + m * ld2 + k / 8] = dst; + } + } + } +} + +void peruate_E(const int32_t *E_src, int32_t *E_dst, const int M, const int K, const int Batch) { + const int m_nums = M / 128; + const int k_nums = K / 128; + for (int b = 0; b < Batch; ++b) { + for (int m = 0; m < m_nums; ++m) { + for (int k = 0; k < k_nums; ++k) { + const int dst_idx = b * m_nums * k_nums * 512 + m * k_nums * 512 + k * 512; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + E_dst[dst_idx + 0 + j * 32 + 4 * i] = E_src[dst_idx + 0 + j * 64 + 4 * i]; + E_dst[dst_idx + 2 + j * 32 + 4 * i] = E_src[dst_idx + 1 + j * 64 + 4 * i]; + E_dst[dst_idx + 1 + j * 32 + 4 * i] = E_src[dst_idx + 32 + j * 64 + 4 * i]; + E_dst[dst_idx + 3 + j * 32 + 4 * i] = E_src[dst_idx + 33 + j * 64 + 4 * i]; + } + for (int j = 0; j < 8; ++j) { + E_dst[dst_idx + 256 + j * 32 + 4 * i] = E_src[dst_idx + 2 + j * 64 + 4 * i]; + E_dst[dst_idx + 258 + j * 32 + 4 * i] = E_src[dst_idx + 3 + j * 64 + 4 * i]; + E_dst[dst_idx + 257 + j * 32 + 4 * i] = E_src[dst_idx + 34 + j * 64 + 4 * i]; + E_dst[dst_idx + 259 + j * 32 + 4 * i] = E_src[dst_idx + 35 + j * 64 + 4 * i]; + } + } + } + } + } +} + +std::vector WFp8AFp8GemmSparseIdxConvert( + const paddle::Tensor& weight, + const int batch_size, + const int M, + const int K) { + + paddle::Tensor weight_temp = paddle::empty({batch_size, M, K / 32}, paddle::DataType::INT32, weight.place()); + paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 32}, paddle::DataType::INT32, weight.place()); + pack_E(weight.data(), weight_temp.data(), M, K, batch_size); + peruate_E(weight_temp.data(), weight_new.data(), M, K, batch_size); + return {weight_new}; +} + + + +PD_BUILD_STATIC_OP(wfp8afp8_gemm_sparse_idx_convert) + .Inputs({"weight"}) + .Outputs({"converted_weight"}) + .Attrs({"batch: int", + "M: int", + "K: int"}) + .SetKernelFn(PD_KERNEL(WFp8AFp8GemmSparseIdxConvert)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index a94c22f48..9861511e8 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -510,6 +510,8 @@ elif paddle.is_compiled_with_cuda(): sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"] os.system("python utils/auto_gen_w4afp8_gemm_kernel.py") sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu") + os.system("python utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py") + sources += find_end_files("gpu_ops/wfp8afp8_sparse_gemm", ".cu") setup( name="fastdeploy_ops", diff --git a/custom_ops/utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py b/custom_ops/utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py new file mode 100644 index 000000000..d490b3583 --- /dev/null +++ b/custom_ops/utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +file_dir = "./gpu_ops/wfp8afp8_sparse_gemm/" + +gemm_template_head = """ +#pragma once +#include +#include +#include +#include +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif +#include +#include +#include +#include +#include +""" +gemm_template_case = """ +void wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}( + const cutlass::float_e4m3_t * weight, + const uint32_t * sparse_idx, + const cutlass::float_e4m3_t * input, + {cutlass_type} * out, + const float *weight_scale, + const int *tokens, + const int max_tokens, + cudaStream_t stream); +""" + +gemm_template_cu_head = """ +#include "paddle/extension.h" +#include "wfp8Afp8_sparse_gemm_template.h" +#include "w8a8_sparse_gemm_kernel.hpp" + +""" +gemm_template_cu_template = """ +void wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}( + const cutlass::float_e4m3_t * weight, + const uint32_t * sparse_idx, + const cutlass::float_e4m3_t * input, + {cutlass_type} * out, + const float *weight_scale, + const int *tokens, + const int max_tokens, + cudaStream_t stream) {{ + + constexpr static int M = {M}; + constexpr static int K = {K}; + constexpr static int Batch = {BATCH}; + constexpr static int TokenPackSize = {PADDING}; + constexpr static int kBlockN = {N}; + constexpr static int kBlockN_TAIL = {TAILN}; + constexpr static int kBlockM = 128; + constexpr static int kBlockK = 128; + constexpr static int kNWarps = 4 + kBlockM / 16; + constexpr static int kStages = 5; + constexpr int kCluster = 1; + static_assert(K % kBlockK == 0); + constexpr int kTiles = K / kBlockK; + + using Kernel_traits = Kernel_traits< + kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles, + M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t, + {cutlass_type}>; + run_gemm + (weight, sparse_idx, input, out, weight_scale, + tokens, max_tokens, stream); +}} +""" + +gemm_case = [ + [128, 128, 1, 0], + [7168, 8192, 8, 0], # eb45T ffn1 +] + +dtype = ["BF16"] + + +def get_cutlass_type(type): + if type == "BF16": + return "cutlass::bfloat16_t" + elif type == "FP16": + return "cutlass::half_t" + + +template_head_file = open(f"{file_dir}wfp8Afp8_sparse_gemm_template.h", "w") +template_head_file.write(gemm_template_head) + +for type in dtype: + for case in gemm_case: + for n in range(32, 257, 32): + template_head_file.write( + gemm_template_case.format( + M=case[0], + K=case[1], + N=n, + BATCH=case[2], + TYPE=type, + PADDING=case[3], + TAILN=0, + cutlass_type=get_cutlass_type(type), + ) + ) + template_head_file.write( + gemm_template_case.format( + M=case[0], + K=case[1], + N=256, + BATCH=case[2], + TYPE=type, + PADDING=case[3], + TAILN=n - 32, + cutlass_type=get_cutlass_type(type), + ) + ) + + template_cu_file = open( + f"{file_dir}wfp8Afp8_sparse_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", + "w", + ) + template_cu_file.write(gemm_template_cu_head) + template_cu_file.write( + gemm_template_cu_template.format( + M=case[0], + 
K=case[1], + N=n, + BATCH=case[2], + TYPE=type, + PADDING=case[3], + TAILN=0, + cutlass_type=get_cutlass_type(type), + ) + ) + + template_cu_file.close() + + template_cu_file = open( + f"{file_dir}wfp8Afp8_sparse_gemm_M{case[0]}_N{256}_TAILN{n-32}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", + "w", + ) + template_cu_file.write(gemm_template_cu_head) + template_cu_file.write( + gemm_template_cu_template.format( + M=case[0], + K=case[1], + N=256, + BATCH=case[2], + TYPE=type, + PADDING=case[3], + TAILN=n - 32, + cutlass_type=get_cutlass_type(type), + ) + ) + + template_cu_file.close() + +for type in dtype: + template_head_file.write("\n") + template_head_file.write( + """#define SPARSE_GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\ + if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format( + TYPE=type + ) + ) + + template_head_file.write("\n") + + for case in gemm_case: + for n in range(32, 257, 32): + template_head_file.write( + """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\ + wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( + M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0 + ) + ) + template_head_file.write("\n") + template_head_file.write( + """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\ + wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( + M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 32 + ) + ) + template_head_file.write("\n") + + template_head_file.write( + """ } else { \\ + PADDLE_THROW(phi::errors::Unimplemented("WFp8aFp8 Sparse not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\ + } \\ + }""" + ) + +template_head_file.close() diff --git a/test/operators/test_wfp8afp8_sparse_gemm.py b/test/operators/test_wfp8afp8_sparse_gemm.py new file mode 100644 index 000000000..e1cc51fef --- /dev/null +++ b/test/operators/test_wfp8afp8_sparse_gemm.py @@ -0,0 +1,163 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
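The unit test that follows validates the op against a dense bf16 reference. One detail worth calling out is peruate_scale: before the per-row weight scales are handed to the kernel, every 16-wide group is interleaved, presumably so that the float2 read in the store path picks up the two scales for the pair of output rows, eight apart, that each MMA thread owns (that "why" is my reading, not stated in the source). The equivalent permutation on a single group:

# Interleave applied to each 16-wide group of weight scales (equivalent to peruate_scale).
def permute16(group):                      # group: scales for 16 consecutive output rows
    out = [None] * 16
    for j in range(0, 16, 2):
        out[j] = group[j // 2]             # first eight scales land on even slots
        out[j + 1] = group[j // 2 + 8]     # last eight scales land on odd slots
    return out

print(permute16(list(range(16))))
# -> [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]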
+ +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import ( + wfp8afp8_gemm_sparse_idx_convert, + wfp8afp8_sparse_gemm, +) + + +def wfp8afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_scale, BATCH, N): + weight = weight_quant.astype("bfloat16") / weight_scale + input_bf16 = input_bf16.astype("bfloat16") + all_tokens = int(tokens.sum()) + out = paddle.zeros([all_tokens, N], dtype="bfloat16") + pre_fix_token = 0 + for i in range(BATCH): + input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :] + out_i = paddle.matmul(input, weight[i], transpose_y=True) + out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i + pre_fix_token += tokens[i] + return out + + +def peruate_scale(weight_scale, N): + BATCH = weight_scale.shape[0] + weight_scale = weight_scale.reshape([BATCH, N]) + temp = paddle.zeros([16]) + for b in range(BATCH): + for n in range(0, N, 16): + temp[:] = weight_scale[b, n : n + 16] + for j in range(0, 16, 2): + weight_scale[b, n + j] = temp[j // 2] + weight_scale[b, n + j + 1] = temp[j // 2 + 8] + return weight_scale + + +def sparse(weight, sparse_idx): + pack_weight = np.zeros([weight.shape[0], weight.shape[1], weight.shape[2] // 2], dtype=weight.dtype) + + idx_select = [ + [0, 1, 2, 3], + [0, 2, 1, 3], + [0, 3, 1, 2], + [1, 2, 0, 3], + [1, 3, 0, 2], + [2, 3, 0, 1], + ] + for b in range(weight.shape[0]): + for i in range(weight.shape[1]): + for j in range(0, weight.shape[2], 4): + idx = sparse_idx[b, i, j // 4] + idx1 = idx_select[idx][0] + idx2 = idx_select[idx][1] + idx3 = idx_select[idx][2] + idx4 = idx_select[idx][3] + + weight[b, i, j + idx1] = 0 + weight[b, i, j + idx2] = 0 + + pack_weight[b, i, j // 4 * 2] = weight[b, i, j + idx3] + pack_weight[b, i, j // 4 * 2 + 1] = weight[b, i, j + idx4] + return weight, pack_weight + + +def convert(weight, sparse_idx, K): + BATCH = weight.shape[0] + temp = np.zeros(weight.shape, dtype=weight.dtype) + + for i in range(0, weight.shape[1], 128): + for j in range(0, 128): + dst_idx = j // 2 + (j % 2) * 64 + temp[:, j + i, :] = weight[:, i + dst_idx, :] + + temp_trans = np.zeros([BATCH, weight.shape[1] // 128, K // 128, 128, 64], dtype=weight.dtype) + temp_E = np.zeros([BATCH, weight.shape[1] // 128, K // 128, 128, 32], dtype=sparse_idx.dtype) + + for b in range(BATCH): + for i in range(weight.shape[1] // 128): + for j in range(K // 128): + temp_trans[b, i, j] = temp[b, i * 128 : i * 128 + 128, j * 64 : j * 64 + 64] + temp_E[b, i, j] = sparse_idx[b, i * 128 : i * 128 + 128, j * 32 : j * 32 + 32] + + return temp_trans, temp_E + + +class TestWFp8Afp8SparseGemm(unittest.TestCase): + def test_wfp8afp8_sparse_gemm(self): + paddle.seed(0) + tokens_per_group = 10 + N = 128 + K = 128 + BATCH = 1 + TokenPadding = 0 + + tokens = [tokens_per_group] * BATCH + tokens_perfix_sum = np.cumsum(tokens) + tokens_perfix_sum = np.insert(tokens_perfix_sum, 0, 0) + + tokens = paddle.to_tensor(tokens, dtype="int32") + tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int32") + + all_tokens = int(tokens.sum()) + + input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn) + + weight = paddle.randn([BATCH, N, K], dtype="bfloat16") + + weight_scale = 40 / weight.abs().max(axis=-1).reshape([BATCH, N, 1]) + + weight_quant = (weight * weight_scale).astype(paddle.float8_e4m3fn).astype("bfloat16") + + weight_quant = weight_quant.numpy() + + sparse_idx = np.random.randint(0, high=6, size=(BATCH, N, K // 4)) + + weight_quant, pack_weight = sparse(weight_quant, 
sparse_idx) + + weight_quant = paddle.to_tensor(weight_quant) + out_naive = wfp8afp8_gemm_naive(input_fp8, weight_quant, tokens, weight_scale, BATCH, N) + + pack_weight, convert_sparse_idx = convert(pack_weight, sparse_idx, K) + + pack_weight = paddle.to_tensor(pack_weight).astype(paddle.float8_e4m3fn) + convert_sparse_idx = paddle.to_tensor(convert_sparse_idx).astype("uint8").cpu() + convert_sparse_idx = wfp8afp8_gemm_sparse_idx_convert(convert_sparse_idx, int(BATCH), int(N), int(K)).cuda() + + weight_scale = paddle.to_tensor(peruate_scale(weight_scale, N)).astype("float32") + + out_pd = paddle.zeros([all_tokens, N], dtype="bfloat16") + + wfp8afp8_sparse_gemm( + input_fp8, + convert_sparse_idx, + pack_weight.reshape([BATCH, N, K // 2]), + tokens_perfix_sum if TokenPadding == 0 else tokens, + 1 / weight_scale, + out_pd, + int(TokenPadding), + int(tokens_per_group), + True, + ) + + print((out_pd - out_naive).abs().max()) + + +if __name__ == "__main__": + unittest.main()
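A final note on the `tokens` argument, based on the comment in WFp8AFp8Gemm and the kernel's indexing: with token_padding_size == 0 the op expects a prefix sum of length Batch + 1 and treats group b as input rows [prefix[b], prefix[b+1]), which is what the test builds above; with token_padding_size > 0 every group occupies a fixed padded slab and `tokens` holds the raw per-group counts instead. A tiny sketch of the first mode (example counts are mine):

# Prefix-sum bookkeeping for the variable-length (token_padding_size == 0) mode.
import numpy as np

tokens_per_group = [10, 3, 7]                      # example counts, not from the source
prefix = np.insert(np.cumsum(tokens_per_group), 0, 0)
print(prefix.tolist())                             # [0, 10, 13, 20]

for b, (lo, hi) in enumerate(zip(prefix[:-1], prefix[1:])):
    print(f"group {b}: input rows [{lo}, {hi})")   # the rows each group's GEMM reads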