Support w4afp8 (#3324)

This commit is contained in:
yangjianfengo1
2025-08-11 19:00:18 +08:00
committed by GitHub
parent c7cb31051b
commit c7993d35cb
9 changed files with 1454 additions and 0 deletions

.gitignore
View File

@@ -164,3 +164,7 @@ build
.ccls-cache
third_party
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h

View File

@@ -0,0 +1,154 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
using namespace cute;
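// Shared-memory layout: the A/B tile buffers used by the mainloop are unioned
// with the C staging buffer used by the epilogue, since the two are never live
// at the same time; the trailing struct holds the TMA pipeline barriers.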
template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
class SmemLayoutB, class SmemLayoutC>
struct SharedStorage {
union {
struct {
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
};
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
};
struct {
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
};
};
template<int kBlockM_, int kBlockN_, int kBlockK_,
int kNWarps_, int kStages_,
int kTiles_, int M_,
int TokenPackSize_,
int TAIL_N_ = 0,
int kClusterM_ = 1,
typename elem_type=cutlass::float_e4m3_t,
typename OutputType = cutlass::bfloat16_t>
struct Kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using ElementOutput = OutputType;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kBlockK = kBlockK_;
static constexpr int kTiles = kTiles_;
static constexpr int TokenPackSize = TokenPackSize_;
static constexpr int M = M_;
static constexpr int TAIL_N = TAIL_N_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
static constexpr int kStages = kStages_;
static_assert(kStages > 1);
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma_TAIL = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
AtomLayoutMNK{}));
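// The A operand (the int4 weight) is stored packed, two 4-bit values per 8-bit
// element, so its shared-memory tile covers kBlockK / 2 elements along K for a
// logical K-extent of kBlockK.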
using SmemLayoutAtomA = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());
using SmemLayoutA = decltype(
tile_to_shape(SmemLayoutAtomA{},
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
using SmemLayoutAtomB = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutB = decltype(
tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomB_TAIL = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
using SmemLayoutB_TAIL = decltype(
tile_to_shape(SmemLayoutAtomB_TAIL{},
make_shape(
shape<1>(TileShape_MNK_TAIL{}),
shape<2>(TileShape_MNK_TAIL{}),
Int<kStages>{})
));
using SmemLayoutAtomC = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, ElementOutput,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
using SharedStorage = SharedStorage<
kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
using TiledCopyCThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyCValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using TiledCopyC = decltype(make_tiled_copy(
TiledCopyCAtom{},
TiledCopyCThrLayout{}, // Thr layout
TiledCopyCValLayout{} // Val layout
));
};

View File

@@ -0,0 +1,405 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
// #include "named_barrier.hpp"
#include "utils.hpp"
using namespace cute;
template <typename Ktraits>
struct CollectiveMainloopFwd {
using Element = typename Ktraits::Element;
using ElementOutput = typename Ktraits::ElementOutput;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using ElementAccum = typename Ktraits::ElementAccum;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int TAIL_N = Ktraits::TAIL_N;
static constexpr int kBlockK = Ktraits::kBlockK;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kTiles = Ktraits::kTiles;
static constexpr int M = Ktraits::M;
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
using GmemTiledCopy = cute::SM90_TMA_LOAD;
using SmemLayoutA = typename Ktraits::SmemLayoutA;
using SmemLayoutB = typename Ktraits::SmemLayoutB;
using SmemLayoutC = typename Ktraits::SmemLayoutC;
using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
using StrideT = cute::Shape<int64_t, _1, int64_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using TMA_A = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}
),
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{})));
using TMA_B = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}
),
take<0, 2>(SmemLayoutB{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})));
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using SmemCopyAtomAB = typename Ktraits::SmemCopyAtomAB;
using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
using TiledCopyC = typename Ktraits::TiledCopyC;
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
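// Host-side Arguments are converted by to_underlying_arguments() into the
// device-side Params, which carry the prebuilt TMA descriptors for A and B.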
struct Arguments {
Element const* ptr_A;
LayoutT layout_A;
Element const* ptr_B;
LayoutT layout_B;
ElementOutput * ptr_C;
LayoutT layout_C;
const float *weight_scale;
const float *input_row_sum;
const int * tokens;
};
struct Params {
LayoutT layout_A;
LayoutT layout_B;
TMA_A tma_load_A;
TMA_B tma_load_B;
ElementOutput * ptr_C;
const float *weight_scale;
const float *input_row_sum;
const int * tokens;
};
Params static
to_underlying_arguments(Arguments const& args) {
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
TMA_A tma_load_A = make_tma_copy(
GmemTiledCopy{},
mA,
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{}));
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
TMA_B tma_load_B = make_tma_copy(
GmemTiledCopy{},
mB,
SmemLayoutB{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}));
return {args.layout_A, args.layout_B, tma_load_A, tma_load_B,
args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
}
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO & tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
const float *input_row_sum,
const float *weight_scale,
const int tokens,
const int pre_fix_tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
using packHalf = typename PackedHalf<ElementOutput>::Type;
Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());
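// Epilogue: shift each accumulator by the per-token input_row_sum (which folds
// in the int4 zero-point correction) and scale by the per-output-channel
// weight_scale, then pack value pairs into half2/bfloat162 registers in the
// order expected by the stmatrix stores below.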
#pragma unroll
for (int i = 0; i < size(tOrO); i+=4) {
const int sum_idx = i * 2;
tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0];
tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0];
tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1];
tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1];
*reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]);
}
uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());
uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
constexpr int k_copy_times = CUR_N / 16;
#pragma unroll
for (int i = 0; i < k_copy_times; i++) {
uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
asm volatile (
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
#endif
}
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
const int remain_tokens = tokens - bidn * kBlockN;
const int col = tidx % 2;
constexpr int kPackSize = 16 / sizeof(ElementOutput);
constexpr int kNumVecElem = kBlockM / kPackSize;
constexpr int copy_len = CUR_N * kNumVecElem;
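// Write the staged tile back to global memory in 16-byte vectors. store_idx
// rearranges the shared-memory index so the data staged by the stmatrix stores
// above comes out in (token, output-channel) row-major order; rows beyond the
// remaining token count of this block are skipped.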
#pragma unroll
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
const int idx_div2 = idx / 2;
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
const int store_global_idx = store_idx * 2 + col;
const int row = store_global_idx / kNumVecElem;
const int col = store_global_idx % kNumVecElem;
if (row >= remain_tokens) {
continue;
}
const int offset = row * (M / kPackSize) + col;
reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
}
}
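// With unpadded tokens (TokenPackSize == 0) the activations of all groups are
// flattened along one token dimension; slice out this group's variable-length
// segment starting at its prefix-sum offset, then tile it for the N blocks.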
template <typename MTensor>
CUTLASS_DEVICE auto get_local_no_packed_tensor(
const MTensor &mB,
const int pre_fix_token,
const int actual_token,
const int bidn) const {
auto g_offset = local_tile(
mB(_, _, 0),
cute::make_shape(1, size<1>(mB)),
make_coord(pre_fix_token, _0{}));
auto g_tensor = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_token, size<2>(mB)),
g_offset.stride()
));
Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
return gB;
}
template <typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState& smem_pipe_write,
SharedStorage &shared_storage,
const int tokens,
const int pre_fix_tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), make_coord(bidm, _));
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));
const int kIters = kTiles / kStages;
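// A single producer thread issues the TMA loads for all K tiles: kIters full
// rounds of kStages stages, then the leftover kTiles % kStages tiles.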
if constexpr (TokenPackSize == 0) {
Tensor gB = get_local_no_packed_tensor(
mB,
pre_fix_tokens,
tokens,
bidn);
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
} else {
auto mB_this_batch = make_tensor(
mB(_, _, bidb).data(),
make_layout(
cute::make_shape(tokens, size<1>(mB)),
mB.stride()
));
Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
}
}
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
TiledMma tiled_mma,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
FrgTensorO &tSrS,
const int tidx) {
using sMemBLayout = std::conditional_t<
CUR_N == kBlockN,
SmemLayoutB,
SmemLayoutB_TAIL
>;
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
auto threadMma = tiled_mma.get_thread_slice(tidx);
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
Tensor tSrB = threadMma.partition_fragment_B(sB);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
const int kIters = kTiles / kStages;
constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);
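// B_STEPS compensates for viewing the kBlockN-wide shared-memory stages through
// the narrower TAIL_N layout: the stage index is scaled so stage s still refers
// to the data the producer wrote for stage s.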
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s));
consumer_wait(pipeline, smem_pipe_read);
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
}
#pragma unroll
for (int i = 0; i < kTiles % kStages; ++i) {
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i));
consumer_wait(pipeline, smem_pipe_read);
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
}
};

View File

@@ -0,0 +1,114 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
using namespace cute;
template<typename T>
struct PackedHalf;
template<>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
};
template<>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
};
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
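// Split packed int4 A data into two FP8 fragments: high nibbles go to dst1,
// low nibbles to dst2. The nibbles are kept as raw e4m3 bit patterns (a nibble
// q decodes to q / 512); the caller is expected to fold that factor and the
// zero point into weight_scale / input_row_sum.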
template <int numel>
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) {
#pragma unroll
for (int i = 0; i < numel; ++i) {
dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;
dst2[i] = src[i] & 0x0f0f0f0f;
}
}
template <int wg_wait=0, bool arrive=true,
bool commit=true, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename TiledMma,
typename ThrCopyA, typename TiledCopyA>
__forceinline__ __device__ void gemm(
TiledMma &tiled_mma,
Tensor0 &tCrA,
Tensor1 &tCsA,
Tensor2 const &tCrB,
Tensor3 &tCrC,
TiledCopyA const &tiled_copy_A,
ThrCopyA const &thr_copy_A) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
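// tCrA1 / tCrA2 hold the two FP8 fragments unpacked from the packed int4 A
// fragment; each packed k_block is multiplied against two consecutive k_blocks
// of B in the loop below.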
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
if (k_block < size<2>(tCrA) - 1) {
cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
}
int32_t * tCrA_data = reinterpret_cast<int32_t *>(tCrA(_,_,k_block).data());
int32_t * tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_,_,k_block).data());
int32_t * tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_,_,k_block).data());
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC);
cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC);
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

View File

@@ -0,0 +1,213 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#include "helper.h"
#include "paddle/extension.h"
#include "w4afp8_gemm_template.h"
void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
assert(K % 64 == 0);
for (int b = 0; b < batch; ++b) {
for (int m = 0; m < M; ++m) {
for (int k = 0; k < K; k+=64) {
for (int k_inner = 0; k_inner < 32; ++k_inner) {
uint8_t temp = 0;
uint8_t left = weight[b * M * K + m * K + k + k_inner];
uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
temp |= left << 4;
temp |= right;
weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast<uint8_t*>(&temp);
}
}
}
}
}
template <typename OutputType>
void DisPatchW4AFp8Gemm(
const cutlass::float_e4m3_t* input,
const cutlass::float_e4m3_t* weight,
const int * tokens,
const float * input_row_sum,
const float * weight_scale,
OutputType * out,
const int token_padding_size,
const int max_tokens,
const int batch_size,
const int M,
const int K,
cudaStream_t stream) {
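// Choose the token block size: round max_tokens up to a multiple of 16 and cap
// it at 256; when it exceeds 256, the remainder is handled by a dedicated
// TailN kernel instantiation for the last, narrower block.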
int kBlockN = (max_tokens + 15) / 16 * 16;
int TailN = 0;
if (kBlockN > 256) {
TailN = kBlockN % 256;
kBlockN = 256;
}
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
GEMM_SWITCH_BF16(
M, K, batch_size, token_padding_size, kBlockN, TailN,
weight,
input,
out,
weight_scale,
input_row_sum,
tokens,
max_tokens,
stream)
} else {
GEMM_SWITCH_FP16(
M, K, batch_size, token_padding_size, kBlockN, TailN,
weight,
input,
out,
weight_scale,
input_row_sum,
tokens,
max_tokens,
stream)
}
}
std::vector<paddle::Tensor> W4AFp8Gemm(
const paddle::Tensor& input,
const paddle::Tensor& weight,
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
const paddle::Tensor& input_row_sum,
const paddle::Tensor& weight_scale,
const int token_padding_size,
const int max_tokens,
const bool is_bfloat16) {
const int batch_size = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2] * 2;
if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
PD_THROW("Only the 'FLOAT8_E4M3FN' input dtype is supported.");
}
if (token_padding_size == 0) {
const int all_tokens = input.dims()[0];
if (is_bfloat16) {
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
} else {
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::half_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
}
} else {
if (is_bfloat16) {
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
} else {
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::half_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
}
}
}
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
const int batch_size = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2];
paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place());
weight_convert(weight.data<uint8_t>(), weight_new.data<uint8_t>(), batch_size, M, K);
return {weight_new};
}
PD_BUILD_STATIC_OP(w4afp8_gemm)
.Inputs({"input",
"weight",
"tokens",
"input_row_sum",
"weight_scale"})
.Outputs({"out"})
.Attrs({"token_padding_size: int",
"max_tokens: int",
"is_bflot16: bool"})
.SetKernelFn(PD_KERNEL(W4AFp8Gemm));
PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert)
.Inputs({"weight"})
.Outputs({"converted_weight"})
.SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));

View File

@@ -0,0 +1,252 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "kernel_traits.h"
#include "mainloop_fwd.h"
template <typename Ktraits>
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_gemm_kernel(
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
using Element = typename Ktraits::Element;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int M = Ktraits::M;
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int TAIL_N = Ktraits::TAIL_N;
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using ElementOutput = typename Ktraits::ElementOutput;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
const int bidm = blockIdx.x;
const int bidn = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
if (tidx == 0) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{});
CollectiveMainloop collective_mainloop;
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
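// Without token padding (TokenPackSize == 0), `tokens` stores a prefix sum, so
// this group's token count is the difference of adjacent entries; with padding,
// each entry is the group's token count directly. CTAs whose N-offset lies
// beyond the token count have nothing to do.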
const int pre_fix_tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] : 0;
const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb + 1] - pre_fix_tokens : mainloop_params.tokens[bidb];
if (bidn * kBlockN >= tokens) {
return;
}
float* input_row_sum = reinterpret_cast<float*>(
shared_memory + sizeof(typename Ktraits::SharedStorage));
if (warp_group_idx == 0) {
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load(
mainloop_params,
pipeline,
smem_pipe_write,
shared_storage,
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
tidx);
} else {
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
PipelineState smem_pipe_read;
typename Ktraits::TiledMma tiled_mma;
typename Ktraits::TiledMma_TAIL tiled_mma_tail;
const int mma_tidx = tidx - NumCopyThreads;
const int lane_id = mma_tidx % 4 * 2;
const float2 weight_scale = reinterpret_cast<const float2*>(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];
if constexpr (TokenPackSize == 0) {
const int input_sum_idx = pre_fix_tokens + bidn * kBlockN;
if (mma_tidx < kBlockN) {
reinterpret_cast<float*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
}
} else {
const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN;
if (mma_tidx < kBlockN / 4) {
reinterpret_cast<float4*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float4*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
}
}
const int remain_tokens = tokens - bidn * kBlockN;
if (TAIL_N > 0 && reamin_tokens < kBlockN) {
Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{}));
collective_mainloop.mma<TAIL_N>(
mainloop_params,
tiled_mma_tail,
pipeline,
smem_pipe_read,
shared_storage,
tSrS_tail,
mma_tidx);
collective_mainloop.store<TAIL_N>(
mainloop_params,
tSrS_tail,
shared_storage,
tiled_mma_tail,
input_row_sum + lane_id,
reinterpret_cast<const float*>(&weight_scale),
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
mma_tidx);
} else {
Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
collective_mainloop.mma<kBlockN>(
mainloop_params,
tiled_mma,
pipeline,
smem_pipe_read,
shared_storage,
tSrS,
mma_tidx);
collective_mainloop.store<kBlockN>(
mainloop_params,
tSrS,
shared_storage,
tiled_mma,
input_row_sum + lane_id,
reinterpret_cast<const float*>(&weight_scale),
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
mma_tidx);
}
}
}
template <int Batch>
auto get_gmem_layout(const int Rows, const int Cols) {
return make_layout(
make_shape(
static_cast<int64_t>(Rows),
static_cast<int64_t>(Cols),
static_cast<int64_t>(Batch)),
make_stride(
static_cast<int64_t>(Cols),
cute::_1{},
static_cast<int64_t>(Rows * Cols)));
}
template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
const float *input_row_sum, const int * tokens, const int max_tokens, cudaStream_t stream) {
using ElementOutput = typename Kernel_traits::ElementOutput;
using Element = typename Kernel_traits::Element;
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(A),
get_gmem_layout<Batch>(M, K / 2),
static_cast<Element const*>(B),
get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens * Batch : TokenPackSize, K),
static_cast<ElementOutput*>(C),
get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
weight_scale,
input_row_sum,
tokens
});
void *kernel;
kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;
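// Dynamic shared memory = the kernel's SharedStorage plus kBlockN floats that
// stage the per-token input_row_sum values.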
int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = M_nums;
grid_dims.y = N_nums;
grid_dims.z = Batch;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(
launch_params, kernel, mainloop_params);
}

View File

@@ -494,6 +494,8 @@ elif paddle.is_compiled_with_cuda():
if cc >= 90 and nvcc_version >= 12.0:
# Hopper optimized mla
sources += find_end_files("gpu_ops/mla_attn", ".cu")
os.system("python utils/auto_gen_w4afp8_gemm_kernel.py")
sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
setup(
name="fastdeploy_ops",

View File

@@ -0,0 +1,207 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
file_dir = "./gpu_ops/w4afp8_gemm/"
gemm_template_head = """
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
"""
gemm_template_case = """
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
const cutlass::float_e4m3_t * weight,
const cutlass::float_e4m3_t * input,
{cutlass_type} * out,
const float *weight_scale,
const float *input_row_sum,
const int *tokens,
const int max_tokens,
cudaStream_t stream);
"""
gemm_template_cu_head = """
#include "paddle/extension.h"
#include "w4afp8_gemm_template.h"
#include "w4afp8_gemm_kernel.hpp"
"""
gemm_template_cu_template = """
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
const cutlass::float_e4m3_t * weight,
const cutlass::float_e4m3_t * input,
{cutlass_type} * out,
const float *weight_scale,
const float *input_row_sum,
const int *tokens,
const int max_tokens,
cudaStream_t stream) {{
constexpr static int M = {M};
constexpr static int K = {K};
constexpr static int Batch = {BATCH};
constexpr static int TokenPackSize = {PADDING};
constexpr static int kBlockN = {N};
constexpr static int kBlockN_TAIL = {TAILN};
constexpr static int kBlockM = 128;
constexpr static int kBlockK = 128;
constexpr static int kNWarps = 4 + kBlockM / 16;
constexpr static int kStages = 5;
constexpr int kCluster = 1;
static_assert(K % kBlockK == 0);
constexpr int kTiles = K / kBlockK;
using Kernel_traits = Kernel_traits<
kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
{cutlass_type}>;
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
Kernel_traits, M, K, Batch, TokenPackSize>
(weight, input, out, weight_scale,
input_row_sum, tokens, max_tokens, stream);
}}
"""
gemm_case = [
[8192, 3584, 8, 0], # eb45T ffn1
[8192, 3584, 8, 2048], # eb45T ffn1
[7168, 8192, 8, 0], # eb45T ffn2
[7168, 8192, 8, 2048], # eb45T ffn2
]
dtype = ["BF16", "FP16"]
def get_cutlass_type(type):
if type == "BF16":
return "cutlass::bfloat16_t"
elif type == "FP16":
return "cutlass::half_t"
template_head_file = open(f"{file_dir}w4afp8_gemm_template.h", "w")
template_head_file.write(gemm_template_head)
for type in dtype:
for case in gemm_case:
for n in range(16, 257, 16):
template_head_file.write(
gemm_template_case.format(
M=case[0],
K=case[1],
N=n,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=0,
cutlass_type=get_cutlass_type(type),
)
)
template_head_file.write(
gemm_template_case.format(
M=case[0],
K=case[1],
N=256,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=n - 16,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file = open(
f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
)
template_cu_file.write(gemm_template_cu_head)
template_cu_file.write(
gemm_template_cu_template.format(
M=case[0],
K=case[1],
N=n,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=0,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file.close()
template_cu_file = open(
f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
)
template_cu_file.write(gemm_template_cu_head)
template_cu_file.write(
gemm_template_cu_template.format(
M=case[0],
K=case[1],
N=256,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=n - 16,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file.close()
for type in dtype:
template_head_file.write("\n")
template_head_file.write(
"""#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
TYPE=type
)
)
template_head_file.write("\n")
for case in gemm_case:
for n in range(16, 257, 16):
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
)
)
template_head_file.write("\n")
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16
)
)
template_head_file.write("\n")
template_head_file.write(
""" } else { \\
PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
} \\
}"""
)
template_head_file.close()

View File

@@ -0,0 +1,103 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
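# Reference GEMM: dequantize the int4 weights (stored with a zero point of 7)
# for each group and run a plain bf16 matmul over that group's token slice.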
def w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N):
all_tokens = int(tokens.sum())
out = paddle.zeros([all_tokens, N], dtype="bfloat16")
pre_fix_token = 0
for i in range(BATCH):
input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
pre_fix_token += tokens[i]
return out
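# Interleave the per-channel scales within each group of 16 output channels
# (0, 8, 1, 9, ...) so they match the order in which the kernel's epilogue
# threads read them as float2 pairs.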
def permute_scale(weight_scale):
weight_scale = weight_scale.reshape([BATCH, N])
temp = paddle.zeros([16])
for b in range(BATCH):
for n in range(0, N, 16):
temp[:] = weight_scale[b, n : n + 16]
for j in range(0, 16, 2):
weight_scale[b, n + j] = temp[j // 2]
weight_scale[b, n + j + 1] = temp[j // 2 + 8]
return weight_scale
paddle.seed(0)
tokens_per_group = 32
N = 8192
K = 3584
BATCH = 8
TokenPadding = 0
tokens = [tokens_per_group] * BATCH
tokens_prefix_sum = np.cumsum(tokens)
tokens_prefix_sum = np.insert(tokens_prefix_sum, 0, 0)
tokens = paddle.to_tensor(tokens, dtype="int32")
tokens_prefix_sum = paddle.to_tensor(tokens_prefix_sum, dtype="int32")
all_tokens = int(tokens.sum())
input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
input_bf16 = input_fp8.astype("bfloat16")
weight = paddle.randn([BATCH, N, K], dtype="bfloat16") / 10
weight_scale = 7 / weight.abs().max(axis=-1).reshape([BATCH, N, 1])
weight_quant = (weight * weight_scale).astype("int") + 7
weight_quant = paddle.clip(weight_quant, 0, 14)
weight_quant = weight_quant.astype("bfloat16")
weight_dequant_scale = 1 / weight_scale.astype("float32")
input_row_sum = input_bf16.sum(axis=1) * -7 / 512
max_tokens = int(tokens.max())
out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
weight_dequant_scale = paddle.to_tensor(permute_scale(weight_dequant_scale) * 512)
weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())
if TokenPadding == 0:
out_cuda = w4afp8_gemm(
input_fp8,
weight_int4.cuda(),
tokens_prefix_sum,
input_row_sum.astype("float32"),
weight_dequant_scale.astype("float32"),
int(TokenPadding),
max_tokens,
True,
)
else:
out_cuda = w4afp8_gemm(
input_fp8,
weight_int4.cuda(),
tokens,
input_row_sum.astype("float32"),
weight_dequant_scale.astype("float32"),
int(TokenPadding),
max_tokens,
True,
)
gap = (out_cuda - out_naive).abs()
assert float(gap.mean()) < 0.07