mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-03 15:56:49 +08:00
[New Feature] Support FP8 group GEMM with 2:4 sparsity (#3463)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
* Support 2:4 sparsity * code style * Add a macro guard around the stmatrix instruction * code style
3 .gitignore (vendored)
@@ -170,3 +170,6 @@ third_party
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h
151 custom_ops/gpu_ops/wfp8afp8_sparse_gemm/kernel_traits.h (Normal file)
@@ -0,0 +1,151 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

using namespace cute;

template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
          class SmemLayoutE,
          class SmemLayoutB, class SmemLayoutC>
struct SharedStorage {
    union {
        struct {
            cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
            cute::array_aligned<uint32_t, cute::cosize_v<SmemLayoutE>> smem_e;
            cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
        };
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
    };

    struct {
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
    };
};

template<int kBlockM_, int kBlockN_, int kBlockK_,
         int kNWarps_, int kStages_,
         int kTiles_, int M_,
         int TokenPackSize_,
         int TAIL_N_ = 0,
         int kClusterM_ = 1,
         typename elem_type=cutlass::float_e4m3_t,
         typename OutputType = cutlass::bfloat16_t>
struct Kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using ElementOutput = OutputType;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);

    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;

    static_assert(kNWarps_ == 12);

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kBlockK = kBlockK_;
    static constexpr int kTiles = kTiles_;
    static constexpr int TokenPackSize = TokenPackSize_;
    static constexpr int TAIL_N = TAIL_N_;
    static constexpr int M = M_;

    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
    using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;

    static constexpr int kStages = kStages_;
    static_assert(kStages > 1);

    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;

    using TiledMma = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));

    using Mma = decltype(cute::GMMA::ss_op_selector_sparse<Element, Element, ElementAccum, TileShape_MNK>());

    using Mma_TAIL = decltype(cute::GMMA::ss_op_selector_sparse<Element, Element, ElementAccum, TileShape_MNK_TAIL>());

    using SmemLayoutAtomA = decltype(
        cutlass::gemm::collective::detail::rs_smem_selector<
            GMMA::Major::K, Element, Int<kBlockM / 2>, Int<kBlockK>>());

    using SmemLayoutA = decltype(
        tile_to_shape(SmemLayoutAtomA{},
                      make_shape(Int<kBlockM / 2>{}, Int<kBlockK>{}, Int<kStages>{})));

    using SmemLayoutAtomB = decltype(
        cutlass::gemm::collective::detail::ss_smem_selector<
            GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
            decltype(cute::get<2>(TileShape_MNK{}))>());

    using SmemLayoutB = decltype(
        tile_to_shape(SmemLayoutAtomB{},
                      make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomB_TAIL = decltype(
        cutlass::gemm::collective::detail::rs_smem_selector<
            GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
            decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());

    using SmemLayoutB_TAIL = decltype(
        tile_to_shape(SmemLayoutAtomB_TAIL{},
                      make_shape(
                          shape<1>(TileShape_MNK_TAIL{}),
                          shape<2>(TileShape_MNK_TAIL{}),
                          Int<kStages>{})
                      ));
    using SmemLayoutAtomC = decltype(
        cutlass::gemm::collective::detail::ss_smem_selector<
            GMMA::Major::K, ElementOutput,
            decltype(cute::get<0>(TileShape_MNK{})),
            decltype(cute::get<1>(TileShape_MNK{}))>());

    using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));

    using SmemLayoutE = Layout<Shape<Int<NumMmaThreads>, Int<kBlockK / 64>, Int<kStages>>>;

    using SharedStorage = SharedStorage<
        kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutE, SmemLayoutB, SmemLayoutC>;

    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;

    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
    static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
    using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
    using TiledCopyCThrLayout = decltype(cute::make_layout(
        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
        LayoutRight{}));
    using TiledCopyCValLayout = decltype(cute::make_layout(
        cute::make_shape(_1{}, Int<kNumVecElem>{}),
        LayoutRight{}));
    using TiledCopyC = decltype(make_tiled_copy(
        TiledCopyCAtom{},
        TiledCopyCThrLayout{}, // Thr layout
        TiledCopyCValLayout{}  // Val layout
        ));
};
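For orientation: because the weight tile arrives 2:4-compressed, SmemLayoutA holds only kBlockM / 2 rows of K values per output tile, and SmemLayoutE reserves one uint32 of sparsity metadata per MMA thread per 64-wide K chunk. A back-of-envelope sizing of the shared-memory budget under the default configuration instantiated later in this PR (kBlockM = 128, kBlockN = 256, kBlockK = 128, kStages = 5, kNWarps = 12) is sketched below; it is an illustration only, not part of the commit, and the authoritative number is sizeof(Kernel_traits::SharedStorage) on the C++ side.

# Back-of-envelope shared-memory budget (illustration, not part of the commit).
kBlockM, kBlockN, kBlockK, kStages, kNWarps = 128, 256, 128, 5, 12
num_mma_threads = kNWarps * 32 - 128             # 256 consumer threads

a_bytes = (kBlockM // 2) * kBlockK               # FP8 A tile, 2:4-compressed: 8 KB per stage
b_bytes = kBlockN * kBlockK                      # FP8 B (activation) tile: 32 KB per stage
e_bytes = num_mma_threads * (kBlockK // 64) * 4  # one uint32 of metadata per MMA thread per 64-wide K chunk: 2 KB per stage
c_bytes = kBlockM * kBlockN * 2                  # bf16 output tile; shares the union with A/E/B

per_stage = a_bytes + b_bytes + e_bytes
total = max(per_stage * kStages, c_bytes)        # SharedStorage puts {A, E, B} and C in a union
print(per_stage, total)                          # 43008 bytes per stage, 215040 bytes total (fits Hopper's per-block limit)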
466 custom_ops/gpu_ops/wfp8afp8_sparse_gemm/mainloop_fwd.h (Normal file)
@@ -0,0 +1,466 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "utils.hpp"

using namespace cute;

template <typename Ktraits>
struct CollectiveMainloopFwd {

    using Element = typename Ktraits::Element;
    using ElementOutput = typename Ktraits::ElementOutput;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;
    using ElementAccum = typename Ktraits::ElementAccum;

    static constexpr int kStages = Ktraits::kStages;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int kBlockK = Ktraits::kBlockK;
    static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int kTiles = Ktraits::kTiles;
    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
    static constexpr int TokenPackSize = Ktraits::TokenPackSize;
    static constexpr int M = Ktraits::M;

    using GmemTiledCopy = cute::SM90_TMA_LOAD;
    using GmemTiledCopyStore = cute::SM90_TMA_STORE;

    using SmemLayoutA = typename Ktraits::SmemLayoutA;
    using SmemLayoutB = typename Ktraits::SmemLayoutB;
    using SmemLayoutC = typename Ktraits::SmemLayoutC;
    using SmemLayoutE = typename Ktraits::SmemLayoutE;
    using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;

    using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
    using StrideT = cute::Shape<int64_t, _1, int64_t>;
    using LayoutT = cute::Layout<ShapeT, StrideT>;

    using WShapeT = cute::Shape<int64_t, int64_t, int64_t, int64_t, int64_t>;
    using WStrideT = cute::Shape<int64_t, _1, int64_t, int64_t, int64_t>;
    using WLayoutT = cute::Layout<WShapeT, WStrideT>;

    using EShapeT = cute::Shape<int64_t, int64_t, int64_t, int64_t, int64_t>;
    using EStrideT = cute::Shape<_1, int64_t, int64_t, int64_t, int64_t>;
    using ELayoutT = cute::Layout<EShapeT, EStrideT>;

    using TMA_A = decltype(make_tma_copy(
        GmemTiledCopy{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            WShapeT{},
            WStrideT{}
        ),
        SmemLayoutA{}(_, _, _0{}),
        select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}),
        size<0>(ClusterShape{})));

    using TMA_B = decltype(make_tma_copy(
        GmemTiledCopy{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            ShapeT{},
            StrideT{}
        ),
        take<0, 2>(SmemLayoutB{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{})));

    using TMA_E = decltype(make_tma_copy(
        GmemTiledCopy{},
        make_tensor(
            make_gmem_ptr(static_cast<uint32_t const*>(nullptr)),
            EShapeT{},
            EStrideT{}
        ),
        SmemLayoutE{}(_, _, _0{}),
        select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}),
        size<0>(ClusterShape{})));

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesE = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutE{})) * cutlass::sizeof_bits_v<int> / 8);

    struct Arguments {
        Element const* ptr_A;
        WLayoutT layout_A;
        uint32_t const* ptr_E;
        ELayoutT layout_E;
        Element const* ptr_B;
        LayoutT layout_B;
        ElementOutput * ptr_C;
        LayoutT layout_C;
        const int *tokens;
        const float *weight_scale;
    };

    struct Params {
        WLayoutT layout_A;
        ELayoutT layout_E;
        LayoutT layout_B;
        TMA_A tma_load_A;
        TMA_E tma_load_E;
        TMA_B tma_load_B;
        const int *tokens;
        const float *weight_scale;
        ElementOutput * ptr_C;
    };

    Params static
    to_underlying_arguments(Arguments const& args) {
        Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
        TMA_A tma_load_A = make_tma_copy(
            GmemTiledCopy{},
            mA,
            SmemLayoutA{}(_, _, _0{}),
            select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}),
            size<0>(ClusterShape{}));
        Tensor mE = make_tensor(make_gmem_ptr(args.ptr_E), args.layout_E);
        TMA_E tma_load_E = make_tma_copy(
            GmemTiledCopy{},
            mE,
            SmemLayoutE{}(_, _, _0{}),
            select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}),
            size<0>(ClusterShape{}));
        Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
        TMA_B tma_load_B = make_tma_copy(
            GmemTiledCopy{},
            mB,
            SmemLayoutB{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));

        return {args.layout_A, args.layout_E, args.layout_B,
                tma_load_A, tma_load_E, tma_load_B,
                args.tokens, args.weight_scale, args.ptr_C};
    }

    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params) {
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_E.get_tma_descriptor());
    }

    template <int CUR_N, typename SharedStorage>
    CUTLASS_DEVICE void
    store(Params const& mainloop_params,
          float * acc_s,
          SharedStorage& shared_storage,
          const int pre_fix_tokens,
          const int tokens,
          const float * weight_scale,
          const int bidm,
          const int bidn,
          const int bidb,
          const int tidx) {
        typename Ktraits::TiledMma tiled_mma;
        using packHalf = typename PackedHalf<ElementOutput>::Type;
        Tensor tOrO_out = make_tensor<ElementOutput>(partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})).layout());

        #pragma unroll
        for (int i = 0; i < size(tOrO_out); i+=4) {
            acc_s[i] *= weight_scale[0];
            acc_s[i + 1] *= weight_scale[0];
            acc_s[i + 2] *= weight_scale[1];
            acc_s[i + 3] *= weight_scale[1];
            *reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(acc_s[i], acc_s[i + 2]);
            *reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(acc_s[i + 1], acc_s[i + 3]);
        }

        uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());

        uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());

        cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);

        constexpr int k_copy_times = CUR_N / 16;

        #pragma unroll
        for (int i = 0; i < k_copy_times; i++) {
            uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
            asm volatile (
                "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
                :: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
#endif
        }

        cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
        const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
        ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;

        const int reamin_tokens = tokens - bidn * kBlockN;

        const int col = tidx % 2;

        constexpr int kPackSize = 16 / sizeof(ElementOutput);
        constexpr int kNumVecElem = kBlockM / kPackSize;
        constexpr int copy_len = CUR_N * kNumVecElem;
        #pragma unroll
        for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
            const int idx_div2 = idx / 2;
            const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
            const int store_global_idx = store_idx * 2 + col;
            const int row = store_global_idx / kNumVecElem;
            const int col = store_global_idx % kNumVecElem;
            if (row >= reamin_tokens) {
                continue;
            }
            const int offset = row * (M / kPackSize) + col;
            reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
        }
    }

    template <typename MTensor>
    CUTLASS_DEVICE auto get_local_packed_tensor(
        const MTensor &mB,
        const int tokens,
        const int bidn) const {

        auto mB_this_batch = make_tensor(
            mB.data(),
            make_layout(
                cute::make_shape(tokens, size<1>(mB)),
                mB.stride()
            ));
        return local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
    }

    template <typename MTensor>
    CUTLASS_DEVICE auto get_local_no_packed_tensor(
        const MTensor &mB,
        const int pre_fix_token,
        const int actual_token,
        const int bidn) const {

        auto g_offset = local_tile(
            mB(_, _, 0),
            cute::make_shape(1, size<1>(mB)),
            make_coord(pre_fix_token, _0{}));

        auto g_tensor = make_tensor(
            g_offset.data(),
            make_layout(
                cute::make_shape(actual_token, size<1>(mB)),
                g_offset.stride()
            ));

        Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));

        return gB;
    }

    template <typename SharedStorage>
    CUTLASS_DEVICE void
    load(Params const& mainloop_params,
         MainloopPipeline pipeline,
         PipelineState& smem_pipe_write,
         SharedStorage &shared_storage,
         const int pre_fix_tokens,
         const int tokens,
         const int bidm,
         const int bidn,
         const int bidb,
         const int tidx) {

        Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
        Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
        Tensor sE = make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{});

        Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
        Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
        Tensor mE = mainloop_params.tma_load_E.get_tma_tensor(mainloop_params.layout_E.shape());

        Tensor gA = local_tile(mA(_, _, _, bidm, bidb), select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}), make_coord(0,0,_));

        Tensor gE = local_tile(mE(_, _, _, bidm, bidb), select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}), make_coord(0, 0));

        auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));

        auto [tEgE, tEsE] = tma_partition(mainloop_params.tma_load_E, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sE), group_modes<0, 2>(gE));

        int lane_predicate = cute::elect_one_sync();
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);

        if constexpr (TokenPackSize == 0) {
            Tensor gB = get_local_no_packed_tensor(
                mB,
                pre_fix_tokens,
                tokens,
                bidn);
            auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));

            const int kIters = kTiles / kStages;
            if (tidx == 0) {
                #pragma unroll
                for (int kiter = 0; kiter < kIters; ++kiter) {
                    #pragma unroll
                    for (int s = 0; s < kStages; s++) {
                        const int i = kiter * kStages + s;
                        pipeline.producer_acquire(smem_pipe_write);
                        copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tAgA(_, i), tAsA(_, s));
                        copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tEgE(_, i), tEsE(_, s));
                        copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tBgB(_, i), tBsB(_, s));
                        ++smem_pipe_write;
                    }
                }

                #pragma unroll
                for (int i = kIters * kStages; i < kTiles; ++i) {
                    pipeline.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tAgA(_, i), tAsA(_, smem_pipe_write.index()));
                    copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tEgE(_, i), tEsE(_, smem_pipe_write.index()));
                    copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tBgB(_, i), tBsB(_, smem_pipe_write.index()));
                    ++smem_pipe_write;
                }
            }
        } else {
            auto mB_this_batch = make_tensor(
                mB(_, _, bidb).data(),
                make_layout(
                    cute::make_shape(tokens, size<1>(mB)),
                    mB.stride()
                ));
            Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
            auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));

            const int kIters = kTiles / kStages;
            if (tidx == 0) {
                #pragma unroll
                for (int kiter = 0; kiter < kIters; ++kiter) {
                    #pragma unroll
                    for (int s = 0; s < kStages; s++) {
                        const int i = kiter * kStages + s;
                        pipeline.producer_acquire(smem_pipe_write);
                        copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tAgA(_, i), tAsA(_, s));
                        copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tEgE(_, i), tEsE(_, s));
                        copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tBgB(_, i), tBsB(_, s));
                        ++smem_pipe_write;
                    }
                }

                #pragma unroll
                for (int i = kIters * kStages; i < kTiles; ++i) {
                    pipeline.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tAgA(_, i), tAsA(_, smem_pipe_write.index()));
                    copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tEgE(_, i), tEsE(_, smem_pipe_write.index()));
                    copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tBgB(_, i), tBsB(_, smem_pipe_write.index()));
                    ++smem_pipe_write;
                }
            }
        }
    }

    template <int CUR_N, typename SharedStorage>
    CUTLASS_DEVICE void
    mma(Params const& mainloop_params,
        MainloopPipeline pipeline,
        PipelineState& smem_pipe_read,
        SharedStorage& shared_storage,
        float *acc_s,
        const int tidx) {

        using sMemBLayout = std::conditional_t<
            CUR_N == kBlockN,
            SmemLayoutB,
            SmemLayoutB_TAIL
        >;

        using Mma = std::conditional_t<
            CUR_N == kBlockN,
            typename Ktraits::Mma,
            typename Ktraits::Mma_TAIL
        >;

        Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
        Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
        Tensor sE = make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{});

        const int wg_idx = tidx / 128;
        const int wg_offset = wg_idx * 64;

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        constexpr int E_STEP = kBlockK / 64 * NumMmaThreads;
        constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);

        const int kIters = kTiles / kStages;
        #pragma unroll
        for (int kiter = 0; kiter < kIters; ++kiter) {
            #pragma unroll
            for (int s = 0; s < kStages; s++) {
                consumer_wait(pipeline, smem_pipe_read);

                gemm<Mma, kBlockK, NumMmaThreads>(
                    sA(_, _, s).data().get().get() + wg_offset,
                    sB(_, _, s * B_STEPS).data().get().get(),
                    acc_s,
                    shared_storage.smem_e.data() + s * E_STEP + tidx);

                pipeline.consumer_release(smem_pipe_read);
                ++smem_pipe_read;
            }
        }

        #pragma unroll
        for (int i = 0; i < kTiles % kStages; ++i) {
            consumer_wait(pipeline, smem_pipe_read);

            gemm<Mma, kBlockK, NumMmaThreads>(
                sA(_, _, i).data().get().get() + wg_offset,
                sB(_, _, i * B_STEPS).data().get().get(),
                acc_s,
                shared_storage.smem_e.data() + i * E_STEP + tidx);

            pipeline.consumer_release(smem_pipe_read);
            ++smem_pipe_read;
        }
    }

};
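The producer side of load() issues the kTiles K-blocks through the kStages-deep TMA pipeline as kTiles / kStages full passes plus a remainder loop that picks its stage via smem_pipe_write.index(). The stage schedule can be sketched in a few lines of Python (illustration only, not part of the commit; index arithmetic without any TMA or barrier semantics):

# Schematic of the producer-side stage schedule in CollectiveMainloopFwd::load().
def producer_schedule(k_tiles: int, k_stages: int):
    schedule = []                                # (k_tile, smem_stage) pairs, in issue order
    k_iters = k_tiles // k_stages
    for kiter in range(k_iters):
        for s in range(k_stages):
            schedule.append((kiter * k_stages + s, s))
    for i in range(k_iters * k_stages, k_tiles):
        schedule.append((i, i % k_stages))       # matches smem_pipe_write.index() cycling
    return schedule

# One configuration from this PR: K = 8192, kBlockK = 128 -> kTiles = 64, kStages = 5.
print(producer_schedule(64, 5)[:7])              # [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 0), (6, 1)]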
100 custom_ops/gpu_ops/wfp8afp8_sparse_gemm/utils.hpp (Normal file)
@@ -0,0 +1,100 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

using namespace cute;

template<typename T>
struct PackedHalf;

template<>
struct PackedHalf<cutlass::half_t> {
    using Type = __half2;
};

template<>
struct PackedHalf<cutlass::bfloat16_t> {
    using Type = nv_bfloat162;
};

template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(
    PointerType smem_ptr,
    int layout_type,
    int leading_byte_offset = 0,
    int stride_byte_offset = 1024) {

    GmmaDescriptor desc;
    auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    desc.bitfield.start_address_ = uint_ptr >> 4;
    desc.bitfield.layout_type_ = layout_type;
    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.bitfield.base_offset_ = 0;
    return desc;
}

template <typename Mma, size_t ...Idx>
__forceinline__ __device__ static void gemm(uint64_t const& desc_a, uint64_t const& desc_b, float* d, const uint32_t e, std::index_sequence<Idx...>) {
    Mma::fma(desc_a, desc_b, d[Idx]..., e, GMMA::ScaleOut::One);
}

template <typename Mma, int kBlockK, int NumMmaThreads, typename T>
__forceinline__ __device__ void gemm(
    const T * sA,
    const T * sB,
    float * acc_c,
    const uint32_t *E) {

    constexpr int acc_num = sizeof(Mma::CRegisters) / sizeof(float);

    warpgroup_arrive();
    // Kept-column index pair -> corresponding metadata nibble (hex value):
    // (0,1) -> 4
    // (0,2) -> 8
    // (0,3) -> 12
    // (1,2) -> 9
    // (1,3) -> 13
    // (2,3) -> 14
    #pragma unroll
    for (int i = 0; i < kBlockK / 64; i++) {
        GmmaDescriptor a_desc = make_smem_desc(sA + i * 32, 1, 0, 1024);
        GmmaDescriptor b_desc = make_smem_desc(sB + i * 64, 1, 0, 1024);
        gemm<Mma>(a_desc, b_desc, acc_c, E[i * NumMmaThreads], std::make_index_sequence<acc_num>{});
    }

    warpgroup_commit_batch();
    warpgroup_wait<0>();
}
@@ -0,0 +1,309 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"

#include "kernel_traits.h"
#include "mainloop_fwd.h"

template <typename Ktraits>
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w8a8_sparse_gemm_kernel(
        CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {

    using Element = typename Ktraits::Element;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);

    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
    static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int TokenPackSize = Ktraits::TokenPackSize;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int TAIL_N = Ktraits::TAIL_N;
    static constexpr int M = Ktraits::M;

    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();

    if (warp_idx == 0 && lane_predicate) {
        CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
    }

    // Obtain warp index
    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

    PipelineParams pipeline_params;
    pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesE + CollectiveMainloop::TmaTransactionBytesB;
    int warp_group_idx = cutlass::canonical_warp_group_idx();
    pipeline_params.role = warp_group_idx == 0
        ? MainloopPipeline::ThreadCategory::Producer
        : MainloopPipeline::ThreadCategory::Consumer;
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NumMmaThreads;

    MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{});

    pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesB;

    CollectiveMainloop collective_mainloop;

    if constexpr (size(ClusterShape{}) > 1) {
        cute::cluster_arrive_relaxed();
        cute::cluster_wait();
    } else {
        __syncthreads();
    }

    const int bidm = blockIdx.x;
    const int bidn = blockIdx.y;
    const int bidb = blockIdx.z;
    const int tidx = threadIdx.x;

    const int pre_fix_tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] : 0;

    const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb + 1] - pre_fix_tokens : mainloop_params.tokens[bidb];

    if (bidn * kBlockN >= tokens) {
        return;
    }

    if (warp_group_idx == 0) {
        cutlass::arch::warpgroup_reg_dealloc<40>();
        PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
        collective_mainloop.load(
            mainloop_params,
            pipeline,
            smem_pipe_write,
            shared_storage,
            pre_fix_tokens,
            tokens,
            bidm,
            bidn,
            bidb,
            tidx);
    } else {
        cutlass::arch::warpgroup_reg_alloc<232>();
        PipelineState smem_pipe_read;

        constexpr int acc_num = sizeof(typename Ktraits::Mma::CRegisters) / sizeof(float);
        float acc_s[acc_num];

        #pragma unroll
        for (int i = 0; i < acc_num; ++i) {
            acc_s[i] = 0.0f;
        }

        const int reamin_tokens = tokens - bidn * kBlockN;

        const int mma_tidx = tidx - NumCopyThreads;

        const float2 weight_scale = reinterpret_cast<const float2*>(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];

        if (TAIL_N > 0 && reamin_tokens < kBlockN) {
            collective_mainloop.mma<TAIL_N>(
                mainloop_params,
                pipeline,
                smem_pipe_read,
                shared_storage,
                acc_s,
                mma_tidx);

            collective_mainloop.store<TAIL_N>(
                mainloop_params,
                acc_s,
                shared_storage,
                pre_fix_tokens,
                tokens,
                reinterpret_cast<const float*>(&weight_scale),
                bidm,
                bidn,
                bidb,
                mma_tidx);
        } else {
            collective_mainloop.mma<kBlockN>(
                mainloop_params,
                pipeline,
                smem_pipe_read,
                shared_storage,
                acc_s,
                mma_tidx);

            collective_mainloop.store<kBlockN>(
                mainloop_params,
                acc_s,
                shared_storage,
                pre_fix_tokens,
                tokens,
                reinterpret_cast<const float*>(&weight_scale),
                bidm,
                bidn,
                bidb,
                mma_tidx);
        }
    }

}

template <int Batch>
auto get_gmem_layout(int Rows, int Cols) {
    return make_layout(
        make_shape(
            static_cast<int64_t>(Rows),
            static_cast<int64_t>(Cols),
            static_cast<int64_t>(Batch)),
        make_stride(
            static_cast<int64_t>(Cols),
            cute::_1{},
            static_cast<int64_t>(Rows * Cols)));
}

template <int Batch>
auto get_weight_gmem_layout(int m_nums, int k_nums, int Rows, int Cols) {
    return make_layout(
        make_shape(
            static_cast<int64_t>(Rows),
            static_cast<int64_t>(Cols),
            static_cast<int64_t>(k_nums),
            static_cast<int64_t>(m_nums),
            static_cast<int64_t>(Batch)),
        make_stride(
            static_cast<int64_t>(Cols),
            cute::_1{},
            static_cast<int64_t>(Rows * Cols),
            static_cast<int64_t>(Rows * Cols * k_nums),
            static_cast<int64_t>(Rows * Cols * k_nums * m_nums)));
}

template <int Batch>
auto get_gmem_e_layout(int ms, int ks, int ks_in, int Cols) {
    return make_layout(
        make_shape(
            static_cast<int64_t>(Cols),
            static_cast<int64_t>(ks_in),
            static_cast<int64_t>(ks),
            static_cast<int64_t>(ms),
            static_cast<int64_t>(Batch)),
        make_stride(
            cute::_1{},
            static_cast<int64_t>(Cols),
            static_cast<int64_t>(ks_in * Cols),
            static_cast<int64_t>(ks * ks_in * Cols),
            static_cast<int64_t>(ms * ks * Cols * 2)));
}

template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int kPackTokenSize>
void run_gemm(
    const InputType * A,
    const uint32_t *E,
    const InputType * B,
    OutputType * C,
    const float *weight_scale,
    const int *tokens_idx,
    const int max_tokens,
    cudaStream_t stream) {

    using ElementOutput = typename Kernel_traits::ElementOutput;
    using Element = typename Kernel_traits::Element;
    using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
    constexpr int NumMmaThreads = Kernel_traits::NumMmaThreads;
    constexpr int kBlockK = Kernel_traits::kBlockK;
    constexpr int kBlockM = Kernel_traits::kBlockM;

    static_assert(M % Kernel_traits::kBlockM == 0);
    constexpr int M_nums = M / Kernel_traits::kBlockM;
    const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;

    constexpr int kTiles = Kernel_traits::kTiles;

    typename CollectiveMainloop::Params mainloop_params =
        CollectiveMainloop::to_underlying_arguments({
            static_cast<Element const*>(A),
            get_weight_gmem_layout<Batch>(M_nums, kTiles, kBlockM / 2, kBlockK),
            static_cast<uint32_t const*>(E),
            get_gmem_e_layout<Batch>(M_nums, kTiles, kBlockK / 64, NumMmaThreads),
            static_cast<Element const*>(B),
            get_gmem_layout<Batch>(kPackTokenSize == 0 ? max_tokens * Batch : kPackTokenSize, K),
            static_cast<ElementOutput*>(C),
            get_gmem_layout<Batch>(M, kPackTokenSize == 0 ? max_tokens : kPackTokenSize),
            tokens_idx,
            weight_scale,
        });

    void *kernel;
    kernel = (void *)w8a8_sparse_gemm_kernel<Kernel_traits>;

    int smem_size = sizeof(typename Kernel_traits::SharedStorage);

    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }

    dim3 grid_dims;
    grid_dims.x = M_nums;
    grid_dims.y = N_nums;
    grid_dims.z = Batch;
    static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
    dim3 block_dims(ctaSize);
    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
    cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
    cutlass::launch_kernel_on_cluster(
        launch_params, kernel, mainloop_params);
}

template <typename InputType, typename OutputType, int M, int K, int Batch, int kPackTokenSize>
void w8a8_sparse_gemm(
    const InputType * A,
    const uint32_t * E,
    const InputType * B,
    OutputType * C,
    const float *weight_scale,
    const int *tokens_idx,
    const int max_tokens,
    cudaStream_t stream) {
    constexpr static int kBlockM = 128;
    constexpr static int kBlockK = 128;
    constexpr static int kNWarps = 4 + kBlockM / 16;
    constexpr static int kStages = 5;
    constexpr int kCluster = 1;
    static_assert(K % kBlockK == 0);
    constexpr int kTiles = K / kBlockK;
    const int max_tokens_pack16 = (max_tokens + 31) / 32 * 32;

    using Kernel_traits = Kernel_traits<kBlockM, 256, kBlockK, kNWarps, kStages, kTiles, M, kPackTokenSize, 0, kCluster, InputType, OutputType>;
    run_gemm<InputType, OutputType, Kernel_traits, M, K, Batch, kPackTokenSize>(A, E, B, C, weight_scale, tokens_idx, max_tokens_pack16, stream);
}
112 custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm.cu (Normal file)
@@ -0,0 +1,112 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#include "helper.h"
#include "paddle/extension.h"
#include "wfp8Afp8_sparse_gemm_template.h"

template <typename OutputType>
void DisPatchWFp8AFp8Gemm(
    const cutlass::float_e4m3_t* input,
    const uint32_t* sparse_idx,
    const cutlass::float_e4m3_t* weight,
    const int * tokens,
    const float * weight_scale,
    OutputType * out,
    const int token_padding_size,
    const int max_tokens,
    const int batch_size,
    const int M,
    const int K,
    cudaStream_t stream) {

    const int max_tokens_pack32 = (max_tokens + 31) / 32 * 32;

    int kBlockN = 256;
    int TailN = max_tokens_pack32 % kBlockN;
    if (max_tokens < 256) {
        kBlockN = max_tokens_pack32;
        TailN = 0;
    }
    if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
        SPARSE_GEMM_SWITCH_BF16(M, K, batch_size, token_padding_size, kBlockN, TailN,
            weight,
            sparse_idx,
            input,
            out,
            weight_scale,
            tokens,
            max_tokens,
            stream)
    } else {
        PD_THROW("Only supported dtype in ['BFLOAT16'].");
    }
}

void WFp8AFp8Gemm(
    const paddle::Tensor& input,
    const paddle::Tensor& sparse_idx,
    const paddle::Tensor& weight,
    const paddle::Tensor& tokens, // If token_padding_size == 0, this tensor holds the prefix sum of per-group token counts; otherwise it holds the number of tokens in each group.
    const paddle::Tensor& weight_scale,
    const paddle::Tensor& out,
    const int token_padding_size,
    const int max_tokens,
    const bool is_bfloat16) {

    const int batch_size = weight.dims()[0];
    const int M = weight.dims()[1];
    const int K = weight.dims()[2] * 2;

    if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
        PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
    }

    if (is_bfloat16) {
        DisPatchWFp8AFp8Gemm(
            reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
            reinterpret_cast<const uint32_t*>(sparse_idx.data<int32_t>()),
            reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<phi::dtype::float8_e4m3fn>()),
            tokens.data<int>(),
            weight_scale.data<float>(),
            reinterpret_cast<cutlass::bfloat16_t*>(const_cast<phi::dtype::bfloat16*>(out.data<phi::dtype::bfloat16>())),
            token_padding_size,
            max_tokens,
            batch_size,
            M,
            K,
            input.stream()
        );
    } else {
        PD_THROW("Only supported dtype in ['BFLOAT16'].");
    }
}

PD_BUILD_STATIC_OP(wfp8afp8_sparse_gemm)
    .Inputs({"input",
             "sparse_idx",
             "weight",
             "tokens",
             "weight_scale",
             "ffn_out"})
    .Outputs({"out"})
    .SetInplaceMap({{"ffn_out", "out"}})
    .Attrs({"token_padding_size: int",
            "max_tokens: int",
            "is_bfloat16: bool"})
    .SetKernelFn(PD_KERNEL(WFp8AFp8Gemm));
@@ -0,0 +1,96 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#include "helper.h"
#include "paddle/extension.h"


void pack_E(const uint8_t *E_src, int32_t *E_dst, const int M, const int K, const int Batch) {
    // Kept-column index pair -> corresponding metadata nibble (hex value):
    // (0,1) -> 4
    // (0,2) -> 8
    // (0,3) -> 12
    // (1,2) -> 9
    // (1,3) -> 13
    // (2,3) -> 14
    const int ld1 = K / 4;
    const int ld2 = K / 4 / 8;
    const uint8_t select_idx[6] = {14, 13, 9, 12, 8, 4};
    for (int b = 0; b < Batch; ++b) {
        for (int m = 0; m < M; ++m) {
            for (int k = 0; k < ld1; k+=8) {
                uint32_t dst = 0;
                for (int k2 = 7; k2 > 0; --k2) {
                    dst |= select_idx[E_src[b * M * ld1 + m * ld1 + k + k2]];
                    dst <<= 4;
                }
                dst |= select_idx[E_src[b * M * ld1 + m * ld1 + k]];
                E_dst[b * M * ld2 + m * ld2 + k / 8] = dst;
            }
        }
    }
}

void peruate_E(const int32_t *E_src, int32_t *E_dst, const int M, const int K, const int Batch) {
    const int m_nums = M / 128;
    const int k_nums = K / 128;
    for (int b = 0; b < Batch; ++b) {
        for (int m = 0; m < m_nums; ++m) {
            for (int k = 0; k < k_nums; ++k) {
                const int dst_idx = b * m_nums * k_nums * 512 + m * k_nums * 512 + k * 512;
                for (int i = 0; i < 8; ++i) {
                    for (int j = 0; j < 8; ++j) {
                        E_dst[dst_idx + 0 + j * 32 + 4 * i] = E_src[dst_idx + 0 + j * 64 + 4 * i];
                        E_dst[dst_idx + 2 + j * 32 + 4 * i] = E_src[dst_idx + 1 + j * 64 + 4 * i];
                        E_dst[dst_idx + 1 + j * 32 + 4 * i] = E_src[dst_idx + 32 + j * 64 + 4 * i];
                        E_dst[dst_idx + 3 + j * 32 + 4 * i] = E_src[dst_idx + 33 + j * 64 + 4 * i];
                    }
                    for (int j = 0; j < 8; ++j) {
                        E_dst[dst_idx + 256 + j * 32 + 4 * i] = E_src[dst_idx + 2 + j * 64 + 4 * i];
                        E_dst[dst_idx + 258 + j * 32 + 4 * i] = E_src[dst_idx + 3 + j * 64 + 4 * i];
                        E_dst[dst_idx + 257 + j * 32 + 4 * i] = E_src[dst_idx + 34 + j * 64 + 4 * i];
                        E_dst[dst_idx + 259 + j * 32 + 4 * i] = E_src[dst_idx + 35 + j * 64 + 4 * i];
                    }
                }
            }
        }
    }
}

std::vector<paddle::Tensor> WFp8AFp8GemmSparseIdxConvert(
    const paddle::Tensor& weight,
    const int batch_size,
    const int M,
    const int K) {

    paddle::Tensor weight_temp = paddle::empty({batch_size, M, K / 32}, paddle::DataType::INT32, weight.place());
    paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 32}, paddle::DataType::INT32, weight.place());
    pack_E(weight.data<uint8_t>(), weight_temp.data<int32_t>(), M, K, batch_size);
    peruate_E(weight_temp.data<int32_t>(), weight_new.data<int32_t>(), M, K, batch_size);
    return {weight_new};
}


PD_BUILD_STATIC_OP(wfp8afp8_gemm_sparse_idx_convert)
    .Inputs({"weight"})
    .Outputs({"converted_weight"})
    .Attrs({"batch: int",
            "M: int",
            "K: int"})
    .SetKernelFn(PD_KERNEL(WFp8AFp8GemmSparseIdxConvert));
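The nibble table used by pack_E is the standard Hopper 2:4 metadata encoding: each group of four columns keeps two, and the nibble stores the two 2-bit kept-column indices as (idx1 << 2) | idx0; pack_E then packs eight such nibbles, lowest K-group first, into one uint32. A small Python self-check of that reading (illustration only, not part of the commit):

# Reproduce the nibble table from the two kept-column indices and mirror pack_E's packing.
pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
nibbles = [(hi << 2) | lo for lo, hi in pairs]
assert nibbles == [4, 8, 12, 9, 13, 14]          # the table above; select_idx stores the same values in reverse order

def pack_eight(group_codes):
    """group_codes: eight values in 0..5, each selecting one of the six kept-column pairs,
    lowest K-group first; packed little-end-first exactly like the loop in pack_E."""
    word = 0
    for code in reversed(group_codes):
        word = (word << 4) | nibbles[code]
    return word & 0xFFFFFFFF

print(hex(pack_eight([0, 1, 2, 3, 4, 5, 0, 1])))  # 0x84ed9c84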
@@ -510,6 +510,8 @@ elif paddle.is_compiled_with_cuda():
    sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"]
    os.system("python utils/auto_gen_w4afp8_gemm_kernel.py")
    sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
    os.system("python utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py")
    sources += find_end_files("gpu_ops/wfp8afp8_sparse_gemm", ".cu")

    setup(
        name="fastdeploy_ops",
207 custom_ops/utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py (Normal file)
@@ -0,0 +1,207 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
file_dir = "./gpu_ops/wfp8afp8_sparse_gemm/"
|
||||
|
||||
gemm_template_head = """
|
||||
#pragma once
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <cuda_fp16.h>
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
"""
|
||||
gemm_template_case = """
|
||||
void wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
||||
const cutlass::float_e4m3_t * weight,
|
||||
const uint32_t * sparse_idx,
|
||||
const cutlass::float_e4m3_t * input,
|
||||
{cutlass_type} * out,
|
||||
const float *weight_scale,
|
||||
const int *tokens,
|
||||
const int max_tokens,
|
||||
cudaStream_t stream);
|
||||
"""
|
||||
|
||||
gemm_template_cu_head = """
|
||||
#include "paddle/extension.h"
|
||||
#include "wfp8Afp8_sparse_gemm_template.h"
|
||||
#include "w8a8_sparse_gemm_kernel.hpp"
|
||||
|
||||
"""
|
||||
gemm_template_cu_template = """
|
||||
void wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
||||
const cutlass::float_e4m3_t * weight,
|
||||
const uint32_t * sparse_idx,
|
||||
const cutlass::float_e4m3_t * input,
|
||||
{cutlass_type} * out,
|
||||
const float *weight_scale,
|
||||
const int *tokens,
|
||||
const int max_tokens,
|
||||
cudaStream_t stream) {{
|
||||
|
||||
constexpr static int M = {M};
|
||||
constexpr static int K = {K};
|
||||
constexpr static int Batch = {BATCH};
|
||||
constexpr static int TokenPackSize = {PADDING};
|
||||
constexpr static int kBlockN = {N};
|
||||
constexpr static int kBlockN_TAIL = {TAILN};
|
||||
constexpr static int kBlockM = 128;
|
||||
constexpr static int kBlockK = 128;
|
||||
constexpr static int kNWarps = 4 + kBlockM / 16;
|
||||
constexpr static int kStages = 5;
|
||||
constexpr int kCluster = 1;
|
||||
static_assert(K % kBlockK == 0);
|
||||
constexpr int kTiles = K / kBlockK;
|
||||
|
||||
using Kernel_traits = Kernel_traits<
|
||||
kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
|
||||
M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
|
||||
{cutlass_type}>;
|
||||
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
|
||||
Kernel_traits, M, K, Batch, TokenPackSize>
|
||||
(weight, sparse_idx, input, out, weight_scale,
|
||||
tokens, max_tokens, stream);
|
||||
}}
|
||||
"""
|
||||
|
||||
gemm_case = [
|
||||
[128, 128, 1, 0],
|
||||
[7168, 8192, 8, 0], # eb45T ffn1
|
||||
]
|
||||
|
||||
dtype = ["BF16"]
|
||||
|
||||
|
||||
def get_cutlass_type(type):
|
||||
if type == "BF16":
|
||||
return "cutlass::bfloat16_t"
|
||||
elif type == "FP16":
|
||||
return "cutlass::half_t"
|
||||
|
||||
|
||||
template_head_file = open(f"{file_dir}wfp8Afp8_sparse_gemm_template.h", "w")
|
||||
template_head_file.write(gemm_template_head)
|
||||
|
||||
for type in dtype:
|
||||
for case in gemm_case:
|
||||
for n in range(32, 257, 32):
|
||||
template_head_file.write(
|
||||
gemm_template_case.format(
|
||||
M=case[0],
|
||||
K=case[1],
|
||||
N=n,
|
||||
BATCH=case[2],
|
||||
TYPE=type,
|
||||
PADDING=case[3],
|
||||
TAILN=0,
|
||||
cutlass_type=get_cutlass_type(type),
|
||||
)
|
||||
)
|
||||
template_head_file.write(
|
||||
gemm_template_case.format(
|
||||
M=case[0],
|
||||
K=case[1],
|
||||
N=256,
|
||||
BATCH=case[2],
|
||||
TYPE=type,
|
||||
PADDING=case[3],
|
||||
TAILN=n - 32,
|
||||
cutlass_type=get_cutlass_type(type),
|
||||
)
|
||||
)
|
||||
|
||||
template_cu_file = open(
|
||||
f"{file_dir}wfp8Afp8_sparse_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu",
|
||||
"w",
|
||||
)
|
||||
template_cu_file.write(gemm_template_cu_head)
|
||||
template_cu_file.write(
|
||||
gemm_template_cu_template.format(
|
||||
M=case[0],
|
||||
K=case[1],
|
||||
N=n,
|
||||
BATCH=case[2],
|
||||
TYPE=type,
|
||||
PADDING=case[3],
|
||||
TAILN=0,
|
||||
cutlass_type=get_cutlass_type(type),
|
||||
)
|
||||
)
|
||||
|
||||
template_cu_file.close()
|
||||
|
||||
template_cu_file = open(
|
||||
f"{file_dir}wfp8Afp8_sparse_gemm_M{case[0]}_N{256}_TAILN{n-32}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu",
|
||||
"w",
|
||||
)
|
||||
template_cu_file.write(gemm_template_cu_head)
|
||||
template_cu_file.write(
|
||||
gemm_template_cu_template.format(
|
||||
M=case[0],
|
||||
K=case[1],
|
||||
N=256,
|
||||
BATCH=case[2],
|
||||
TYPE=type,
|
||||
PADDING=case[3],
|
||||
TAILN=n - 32,
|
||||
cutlass_type=get_cutlass_type(type),
|
||||
)
|
||||
)
|
||||
|
||||
template_cu_file.close()
|
||||
|
||||
for type in dtype:
    template_head_file.write("\n")
    template_head_file.write(
        """#define SPARSE_GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
            TYPE=type
        )
    )

    template_head_file.write("\n")

    for case in gemm_case:
        for n in range(32, 257, 32):
            template_head_file.write(
                """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
                    M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
                )
            )
            template_head_file.write("\n")
            template_head_file.write(
                """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
                    M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 32
                )
            )
            template_head_file.write("\n")

    template_head_file.write(
        """ } else { \\
PADDLE_THROW(phi::errors::Unimplemented("WFp8aFp8 Sparse not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
} \\
}"""
    )

template_head_file.close()
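For reference, a minimal sketch (assuming the first gemm_case entry: M=128, K=128, BATCH=1, PADDING=0, BF16) of the instantiation names the loops above emit, following the wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE} naming used in the generated switch:

for n in range(32, 257, 32):
    # One kernel per full kBlockN tile width ...
    print(f"wfp8afp8_sparse_gemm_M128_N{n}_TAILN0_K128_B1_P0_BF16")
    # ... plus a tail variant pinned at kBlockN=256 with TailN = n - 32.
    print(f"wfp8afp8_sparse_gemm_M128_N256_TAILN{n - 32}_K128_B1_P0_BF16")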
163
test/operators/test_wfp8afp8_sparse_gemm.py
Normal file
@@ -0,0 +1,163 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import (
    wfp8afp8_gemm_sparse_idx_convert,
    wfp8afp8_sparse_gemm,
)


def wfp8afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_scale, BATCH, N):
    # Reference path: dequantize the weights with the per-channel scale and run
    # a plain bf16 matmul for each batch group's slice of tokens.
    weight = weight_quant.astype("bfloat16") / weight_scale
    input_bf16 = input_bf16.astype("bfloat16")
    all_tokens = int(tokens.sum())
    out = paddle.zeros([all_tokens, N], dtype="bfloat16")
    pre_fix_token = 0
    for i in range(BATCH):
        input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
        out_i = paddle.matmul(input, weight[i], transpose_y=True)
        out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
        pre_fix_token += tokens[i]
    return out


def peruate_scale(weight_scale, N):
    BATCH = weight_scale.shape[0]
    weight_scale = weight_scale.reshape([BATCH, N])
    temp = paddle.zeros([16])
    for b in range(BATCH):
        for n in range(0, N, 16):
            temp[:] = weight_scale[b, n : n + 16]
            for j in range(0, 16, 2):
                weight_scale[b, n + j] = temp[j // 2]
                weight_scale[b, n + j + 1] = temp[j // 2 + 8]
    return weight_scale

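# Illustration of the reordering above (hypothetical group of 16 scales 0..15):
# after peruate_scale, the group reads [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15],
# i.e. scale i and scale i + 8 are interleaved into adjacent positions.

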
def sparse(weight, sparse_idx):
    pack_weight = np.zeros([weight.shape[0], weight.shape[1], weight.shape[2] // 2], dtype=weight.dtype)

    idx_select = [
        [0, 1, 2, 3],
        [0, 2, 1, 3],
        [0, 3, 1, 2],
        [1, 2, 0, 3],
        [1, 3, 0, 2],
        [2, 3, 0, 1],
    ]
    for b in range(weight.shape[0]):
        for i in range(weight.shape[1]):
            for j in range(0, weight.shape[2], 4):
                idx = sparse_idx[b, i, j // 4]
                idx1 = idx_select[idx][0]
                idx2 = idx_select[idx][1]
                idx3 = idx_select[idx][2]
                idx4 = idx_select[idx][3]

                weight[b, i, j + idx1] = 0
                weight[b, i, j + idx2] = 0

                pack_weight[b, i, j // 4 * 2] = weight[b, i, j + idx3]
                pack_weight[b, i, j // 4 * 2 + 1] = weight[b, i, j + idx4]
    return weight, pack_weight

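# Worked illustration of the 2:4 packing above (hypothetical 1 x 1 x 4 weight, idx = 3):
#     _w, _p = sparse(np.arange(1.0, 5.0).reshape([1, 1, 4]), np.array([[[3]]]))
# idx_select[3] = [1, 2, 0, 3], so positions 1 and 2 are zeroed and positions 0 and 3
# are kept: _w becomes [[[1, 0, 0, 4]]] and the packed half-width _p is [[[1, 4]]].

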
def convert(weight, sparse_idx, K):
    # Rearrange the packed weights and sparse metadata into the kernel's tile layout:
    # within every block of 128 output rows, row j is taken from source row
    # j // 2 + (j % 2) * 64, then the result is split into [128, 64] weight tiles and
    # [128, 32] metadata tiles per (N-block, K-block).
    BATCH = weight.shape[0]
    temp = np.zeros(weight.shape, dtype=weight.dtype)

    for i in range(0, weight.shape[1], 128):
        for j in range(0, 128):
            dst_idx = j // 2 + (j % 2) * 64
            temp[:, j + i, :] = weight[:, i + dst_idx, :]

    temp_trans = np.zeros([BATCH, weight.shape[1] // 128, K // 128, 128, 64], dtype=weight.dtype)
    temp_E = np.zeros([BATCH, weight.shape[1] // 128, K // 128, 128, 32], dtype=sparse_idx.dtype)

    for b in range(BATCH):
        for i in range(weight.shape[1] // 128):
            for j in range(K // 128):
                temp_trans[b, i, j] = temp[b, i * 128 : i * 128 + 128, j * 64 : j * 64 + 64]
                temp_E[b, i, j] = sparse_idx[b, i * 128 : i * 128 + 128, j * 32 : j * 32 + 32]

    return temp_trans, temp_E


class TestWFp8Afp8SparseGemm(unittest.TestCase):
    def test_wfp8afp8_sparse_gemm(self):
        paddle.seed(0)
        tokens_per_group = 10
        N = 128
        K = 128
        BATCH = 1
        TokenPadding = 0

        tokens = [tokens_per_group] * BATCH
        tokens_perfix_sum = np.cumsum(tokens)
        tokens_perfix_sum = np.insert(tokens_perfix_sum, 0, 0)

        tokens = paddle.to_tensor(tokens, dtype="int32")
        tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int32")

        all_tokens = int(tokens.sum())

        input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn)

        weight = paddle.randn([BATCH, N, K], dtype="bfloat16")

        weight_scale = 40 / weight.abs().max(axis=-1).reshape([BATCH, N, 1])

        weight_quant = (weight * weight_scale).astype(paddle.float8_e4m3fn).astype("bfloat16")

        weight_quant = weight_quant.numpy()

        sparse_idx = np.random.randint(0, high=6, size=(BATCH, N, K // 4))

        # Prune to 2:4 sparsity, then build the bf16 reference result from the
        # pruned (still unpacked) weights.
        weight_quant, pack_weight = sparse(weight_quant, sparse_idx)

        weight_quant = paddle.to_tensor(weight_quant)
        out_naive = wfp8afp8_gemm_naive(input_fp8, weight_quant, tokens, weight_scale, BATCH, N)

        # Pack weights and metadata into the kernel's tile layout and run the sparse GEMM.
        pack_weight, convert_sparse_idx = convert(pack_weight, sparse_idx, K)

        pack_weight = paddle.to_tensor(pack_weight).astype(paddle.float8_e4m3fn)
        convert_sparse_idx = paddle.to_tensor(convert_sparse_idx).astype("uint8").cpu()
        convert_sparse_idx = wfp8afp8_gemm_sparse_idx_convert(convert_sparse_idx, int(BATCH), int(N), int(K)).cuda()

        weight_scale = paddle.to_tensor(peruate_scale(weight_scale, N)).astype("float32")

        out_pd = paddle.zeros([all_tokens, N], dtype="bfloat16")

        wfp8afp8_sparse_gemm(
            input_fp8,
            convert_sparse_idx,
            pack_weight.reshape([BATCH, N, K // 2]),
            tokens_perfix_sum if TokenPadding == 0 else tokens,
            1 / weight_scale,
            out_pd,
            int(TokenPadding),
            int(tokens_per_group),
            True,
        )

        print((out_pd - out_naive).abs().max())
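        # The line above only prints the max abs error; a tolerance assertion along
        # these lines (the thresholds are assumptions, not part of the original test)
        # would make the case self-verifying:
        # np.testing.assert_allclose(
        #     out_pd.astype("float32").numpy(), out_naive.astype("float32").numpy(), rtol=0.05, atol=0.5
        # )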
if __name__ == "__main__":
    unittest.main()