mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
支持w4afp8 (#3324)
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -164,3 +164,7 @@ build
|
||||
.ccls-cache
|
||||
|
||||
third_party
|
||||
|
||||
|
||||
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
|
||||
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
|
154
custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h
Normal file
154
custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
|
||||
class SmemLayoutB, class SmemLayoutC>
|
||||
struct SharedStorage {
|
||||
union {
|
||||
struct {
|
||||
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
|
||||
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
|
||||
};
|
||||
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
|
||||
};
|
||||
|
||||
struct {
|
||||
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
|
||||
};
|
||||
};
|
||||
|
||||
template<int kBlockM_, int kBlockN_, int kBlockK_,
|
||||
int kNWarps_, int kStages_,
|
||||
int kTiles_, int M_,
|
||||
int TokenPackSize_,
|
||||
int TAIL_N_ = 0,
|
||||
int kClusterM_ = 1,
|
||||
typename elem_type=cutlass::float_e4m3_t,
|
||||
typename OutputType = cutlass::bfloat16_t>
|
||||
struct Kernel_traits {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using ElementOutput = OutputType;
|
||||
static_assert(cutlass::sizeof_bits_v<Element> == 8);
|
||||
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
|
||||
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
|
||||
|
||||
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kBlockK = kBlockK_;
|
||||
static constexpr int kTiles = kTiles_;
|
||||
static constexpr int TokenPackSize = TokenPackSize_;
|
||||
static constexpr int M = M_;
|
||||
static constexpr int TAIL_N = TAIL_N_;
|
||||
|
||||
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
|
||||
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
|
||||
|
||||
static constexpr int kClusterM = kClusterM_;
|
||||
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
|
||||
|
||||
static constexpr int kStages = kStages_;
|
||||
static_assert(kStages > 1);
|
||||
|
||||
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using TiledMma_TAIL = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using SmemLayoutAtomA = decltype(
|
||||
cutlass::gemm::collective::detail::rs_smem_selector<
|
||||
GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());
|
||||
|
||||
using SmemLayoutA = decltype(
|
||||
tile_to_shape(SmemLayoutAtomA{},
|
||||
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
|
||||
|
||||
using SmemLayoutAtomB = decltype(
|
||||
cutlass::gemm::collective::detail::rs_smem_selector<
|
||||
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
using SmemLayoutB = decltype(
|
||||
tile_to_shape(SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
|
||||
using SmemLayoutAtomB_TAIL = decltype(
|
||||
cutlass::gemm::collective::detail::rs_smem_selector<
|
||||
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
|
||||
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
|
||||
|
||||
using SmemLayoutB_TAIL = decltype(
|
||||
tile_to_shape(SmemLayoutAtomB_TAIL{},
|
||||
make_shape(
|
||||
shape<1>(TileShape_MNK_TAIL{}),
|
||||
shape<2>(TileShape_MNK_TAIL{}),
|
||||
Int<kStages>{})
|
||||
));
|
||||
|
||||
using SmemLayoutAtomC = decltype(
|
||||
cutlass::gemm::collective::detail::rs_smem_selector<
|
||||
GMMA::Major::K, ElementOutput,
|
||||
decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<1>(TileShape_MNK{}))>());
|
||||
|
||||
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
|
||||
|
||||
using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
|
||||
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
|
||||
|
||||
using SharedStorage = SharedStorage<
|
||||
kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;
|
||||
|
||||
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
|
||||
using PipelineState = typename cutlass::PipelineState<kStages>;
|
||||
|
||||
|
||||
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
|
||||
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
|
||||
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
|
||||
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
|
||||
using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
|
||||
using TiledCopyCThrLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
|
||||
LayoutRight{}));
|
||||
using TiledCopyCValLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(_1{}, Int<kNumVecElem>{}),
|
||||
LayoutRight{}));
|
||||
using TiledCopyC = decltype(make_tiled_copy(
|
||||
TiledCopyCAtom{},
|
||||
TiledCopyCThrLayout{}, // Thr layout
|
||||
TiledCopyCValLayout{} // Val layout
|
||||
));
|
||||
};
|
405
custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h
Normal file
405
custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h
Normal file
@@ -0,0 +1,405 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
// #include "named_barrier.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
|
||||
using namespace cute;
|
||||
template <typename Ktraits>
|
||||
struct CollectiveMainloopFwd {
|
||||
|
||||
using Element = typename Ktraits::Element;
|
||||
using ElementOutput = typename Ktraits::ElementOutput;
|
||||
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
||||
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
|
||||
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
||||
using ElementAccum = typename Ktraits::ElementAccum;
|
||||
|
||||
static constexpr int kStages = Ktraits::kStages;
|
||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||
static constexpr int TAIL_N = Ktraits::TAIL_N;
|
||||
static constexpr int kBlockK = Ktraits::kBlockK;
|
||||
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int kTiles = Ktraits::kTiles;
|
||||
static constexpr int M = Ktraits::M;
|
||||
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
|
||||
|
||||
using GmemTiledCopy = cute::SM90_TMA_LOAD;
|
||||
|
||||
|
||||
using SmemLayoutA = typename Ktraits::SmemLayoutA;
|
||||
using SmemLayoutB = typename Ktraits::SmemLayoutB;
|
||||
using SmemLayoutC = typename Ktraits::SmemLayoutC;
|
||||
using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;
|
||||
|
||||
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
|
||||
using StrideT = cute::Shape<int64_t, _1, int64_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
using TMA_A = decltype(make_tma_copy(
|
||||
GmemTiledCopy{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
||||
ShapeT{},
|
||||
StrideT{}
|
||||
),
|
||||
SmemLayoutA{}(_, _, _0{}),
|
||||
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
|
||||
size<0>(ClusterShape{})));
|
||||
|
||||
using TMA_B = decltype(make_tma_copy(
|
||||
GmemTiledCopy{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
||||
ShapeT{},
|
||||
StrideT{}
|
||||
),
|
||||
take<0, 2>(SmemLayoutB{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})));
|
||||
|
||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
using SmemCopyAtomAB = typename Ktraits::SmemCopyAtomAB;
|
||||
using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
|
||||
using TiledCopyC = typename Ktraits::TiledCopyC;
|
||||
|
||||
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
|
||||
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
|
||||
|
||||
struct Arguments {
|
||||
Element const* ptr_A;
|
||||
LayoutT layout_A;
|
||||
Element const* ptr_B;
|
||||
LayoutT layout_B;
|
||||
ElementOutput * ptr_C;
|
||||
LayoutT layout_C;
|
||||
const float *weight_scale;
|
||||
const float *input_row_sum;
|
||||
const int * tokens;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
LayoutT layout_A;
|
||||
LayoutT layout_B;
|
||||
TMA_A tma_load_A;
|
||||
TMA_B tma_load_B;
|
||||
ElementOutput * ptr_C;
|
||||
const float *weight_scale;
|
||||
const float *input_row_sum;
|
||||
const int * tokens;
|
||||
};
|
||||
|
||||
|
||||
Params static
|
||||
to_underlying_arguments(Arguments const& args) {
|
||||
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
|
||||
TMA_A tma_load_A = make_tma_copy(
|
||||
GmemTiledCopy{},
|
||||
mA,
|
||||
SmemLayoutA{}(_, _, _0{}),
|
||||
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
|
||||
size<0>(ClusterShape{}));
|
||||
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
|
||||
TMA_B tma_load_B = make_tma_copy(
|
||||
GmemTiledCopy{},
|
||||
mB,
|
||||
SmemLayoutB{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{}));
|
||||
|
||||
return {args.layout_A, args.layout_B, tma_load_A, tma_load_B,
|
||||
args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
||||
CUTLASS_DEVICE void
|
||||
store(Params const& mainloop_params,
|
||||
FrgTensorO & tOrO,
|
||||
SharedStorage& shared_storage,
|
||||
TiledMma tiled_mma,
|
||||
const float *input_row_sum,
|
||||
const float *weight_scale,
|
||||
const int tokens,
|
||||
const int pre_fix_tokens,
|
||||
const int bidm,
|
||||
const int bidn,
|
||||
const int bidb,
|
||||
const int tidx) {
|
||||
|
||||
using packHalf = typename PackedHalf<ElementOutput>::Type;
|
||||
Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tOrO); i+=4) {
|
||||
const int sum_idx = i * 2;
|
||||
tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0];
|
||||
tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0];
|
||||
tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1];
|
||||
tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1];
|
||||
*reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]);
|
||||
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]);
|
||||
}
|
||||
|
||||
uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());
|
||||
|
||||
uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
|
||||
|
||||
constexpr int k_copy_times = CUR_N / 16;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < k_copy_times; i++) {
|
||||
uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
|
||||
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
||||
asm volatile (
|
||||
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
|
||||
#endif
|
||||
}
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
|
||||
const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
|
||||
ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
|
||||
|
||||
const int reamin_tokens = tokens - bidn * kBlockN;
|
||||
|
||||
const int col = tidx % 2;
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(ElementOutput);
|
||||
constexpr int kNumVecElem = kBlockM / kPackSize;
|
||||
constexpr int copy_len = CUR_N * kNumVecElem;
|
||||
#pragma unroll
|
||||
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
|
||||
const int idx_div2 = idx / 2;
|
||||
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
|
||||
const int store_global_idx = store_idx * 2 + col;
|
||||
const int row = store_global_idx / kNumVecElem;
|
||||
const int col = store_global_idx % kNumVecElem;
|
||||
if (row >= reamin_tokens) {
|
||||
continue;
|
||||
}
|
||||
const int offset = row * (M / kPackSize) + col;
|
||||
reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename MTensor>
|
||||
CUTLASS_DEVICE auto get_local_no_packed_tensor(
|
||||
const MTensor &mB,
|
||||
const int pre_fix_token,
|
||||
const int actual_token,
|
||||
const int bidn) const {
|
||||
|
||||
auto g_offset = local_tile(
|
||||
mB(_, _, 0),
|
||||
cute::make_shape(1, size<1>(mB)),
|
||||
make_coord(pre_fix_token, _0{}));
|
||||
|
||||
auto g_tensor = make_tensor(
|
||||
g_offset.data(),
|
||||
make_layout(
|
||||
cute::make_shape(actual_token, size<2>(mB)),
|
||||
g_offset.stride()
|
||||
));
|
||||
|
||||
Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
|
||||
|
||||
return gB;
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void
|
||||
load(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState& smem_pipe_write,
|
||||
SharedStorage &shared_storage,
|
||||
const int tokens,
|
||||
const int pre_fix_tokens,
|
||||
const int bidm,
|
||||
const int bidn,
|
||||
const int bidb,
|
||||
const int tidx) {
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
|
||||
|
||||
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
|
||||
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
|
||||
|
||||
Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), make_coord(bidm, _));
|
||||
|
||||
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));
|
||||
|
||||
const int kIters = kTiles / kStages;
|
||||
|
||||
if constexpr (TokenPackSize == 0) {
|
||||
Tensor gB = get_local_no_packed_tensor(
|
||||
mB,
|
||||
pre_fix_tokens,
|
||||
tokens,
|
||||
bidn);
|
||||
|
||||
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
|
||||
|
||||
if (tidx == 0) {
|
||||
#pragma unroll
|
||||
for (int kiter = 0; kiter < kIters; ++kiter) {
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kStages; s++) {
|
||||
const int i = kiter * kStages + s;
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
|
||||
|
||||
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = kIters * kStages; i < kTiles; ++i) {
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
|
||||
|
||||
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto mB_this_batch = make_tensor(
|
||||
mB(_, _, bidb).data(),
|
||||
make_layout(
|
||||
cute::make_shape(tokens, size<1>(mB)),
|
||||
mB.stride()
|
||||
));
|
||||
Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
|
||||
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
|
||||
|
||||
if (tidx == 0) {
|
||||
#pragma unroll
|
||||
for (int kiter = 0; kiter < kIters; ++kiter) {
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kStages; s++) {
|
||||
const int i = kiter * kStages + s;
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
|
||||
|
||||
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = kIters * kStages; i < kTiles; ++i) {
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
|
||||
|
||||
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
||||
CUTLASS_DEVICE void
|
||||
mma(Params const& mainloop_params,
|
||||
TiledMma tiled_mma,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState& smem_pipe_read,
|
||||
SharedStorage& shared_storage,
|
||||
FrgTensorO &tSrS,
|
||||
const int tidx) {
|
||||
|
||||
using sMemBLayout = std::conditional_t<
|
||||
CUR_N == kBlockN,
|
||||
SmemLayoutB,
|
||||
SmemLayoutB_TAIL
|
||||
>;
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
|
||||
auto threadMma = tiled_mma.get_thread_slice(tidx);
|
||||
|
||||
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
|
||||
|
||||
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
|
||||
Tensor tSrB = threadMma.partition_fragment_B(sB);
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
const int kIters = kTiles / kStages;
|
||||
|
||||
constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);
|
||||
|
||||
#pragma unroll
|
||||
for (int kiter = 0; kiter < kIters; ++kiter) {
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kStages; s++) {
|
||||
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s));
|
||||
consumer_wait(pipeline, smem_pipe_read);
|
||||
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
|
||||
pipeline.consumer_release(smem_pipe_read);
|
||||
++smem_pipe_read;
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kTiles % kStages; ++i) {
|
||||
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i));
|
||||
consumer_wait(pipeline, smem_pipe_read);
|
||||
|
||||
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
|
||||
pipeline.consumer_release(smem_pipe_read);
|
||||
++smem_pipe_read;
|
||||
}
|
||||
}
|
||||
};
|
114
custom_ops/gpu_ops/w4afp8_gemm/utils.hpp
Normal file
114
custom_ops/gpu_ops/w4afp8_gemm/utils.hpp
Normal file
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename T>
|
||||
struct PackedHalf;
|
||||
|
||||
template<>
|
||||
struct PackedHalf<cutlass::half_t> {
|
||||
using Type = __half2;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct PackedHalf<cutlass::bfloat16_t> {
|
||||
using Type = nv_bfloat162;
|
||||
};
|
||||
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template <int numel>
|
||||
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < numel; ++i) {
|
||||
dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;
|
||||
dst2[i] = src[i] & 0x0f0f0f0f;
|
||||
}
|
||||
}
|
||||
|
||||
template <int wg_wait=0, bool arrive=true,
|
||||
bool commit=true, typename Tensor0, typename Tensor1,
|
||||
typename Tensor2, typename Tensor3, typename TiledMma,
|
||||
typename ThrCopyA, typename TiledCopyA>
|
||||
__forceinline__ __device__ void gemm(
|
||||
TiledMma &tiled_mma,
|
||||
Tensor0 &tCrA,
|
||||
Tensor1 &tCsA,
|
||||
Tensor2 const &tCrB,
|
||||
Tensor3 &tCrC,
|
||||
TiledCopyA const &tiled_copy_A,
|
||||
ThrCopyA const &thr_copy_A) {
|
||||
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
||||
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
|
||||
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
|
||||
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (arrive) {
|
||||
warpgroup_arrive();
|
||||
}
|
||||
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
|
||||
|
||||
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
||||
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
if (k_block < size<2>(tCrA) - 1) {
|
||||
cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
|
||||
}
|
||||
int32_t * tCrA_data = reinterpret_cast<int32_t *>(tCrA(_,_,k_block).data());
|
||||
int32_t * tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_,_,k_block).data());
|
||||
int32_t * tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_,_,k_block).data());
|
||||
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
|
||||
|
||||
cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC);
|
||||
cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC);
|
||||
}
|
||||
if constexpr (commit) {
|
||||
warpgroup_commit_batch();
|
||||
}
|
||||
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||
}
|
213
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
Normal file
213
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "w4afp8_gemm_template.h"
|
||||
|
||||
|
||||
void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
|
||||
assert(K % 64 == 0);
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
for (int m = 0; m < M; ++m) {
|
||||
for (int k = 0; k < K; k+=64) {
|
||||
for (int k_inner = 0; k_inner < 32; ++k_inner) {
|
||||
uint8_t temp = 0;
|
||||
uint8_t left = weight[b * M * K + m * K + k + k_inner];
|
||||
uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
|
||||
temp |= left << 4;
|
||||
temp |= right;
|
||||
weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast<uint8_t*>(&temp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename OutputType>
|
||||
void DisPatchW4AFp8Gemm(
|
||||
const cutlass::float_e4m3_t* input,
|
||||
const cutlass::float_e4m3_t* weight,
|
||||
const int * tokens,
|
||||
const float * input_row_sum,
|
||||
const float * weight_scale,
|
||||
OutputType * out,
|
||||
const int token_padding_size,
|
||||
const int max_tokens,
|
||||
const int batch_size,
|
||||
const int M,
|
||||
const int K,
|
||||
cudaStream_t stream) {
|
||||
|
||||
int kBlockN = (max_tokens + 15) / 16 * 16;
|
||||
int TailN = 0;
|
||||
if (kBlockN > 256) {
|
||||
TailN = kBlockN % 256;
|
||||
kBlockN = 256;
|
||||
}
|
||||
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
|
||||
GEMM_SWITCH_BF16(
|
||||
M, K, batch_size, token_padding_size, kBlockN, TailN,
|
||||
weight,
|
||||
input,
|
||||
out,
|
||||
weight_scale,
|
||||
input_row_sum,
|
||||
tokens,
|
||||
max_tokens,
|
||||
stream)
|
||||
} else {
|
||||
GEMM_SWITCH_FP16(
|
||||
M, K, batch_size, token_padding_size, kBlockN, TailN,
|
||||
weight,
|
||||
input,
|
||||
out,
|
||||
weight_scale,
|
||||
input_row_sum,
|
||||
tokens,
|
||||
max_tokens,
|
||||
stream)
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> W4AFp8Gemm(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& weight,
|
||||
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
|
||||
const paddle::Tensor& input_row_sum,
|
||||
const paddle::Tensor& weight_scale,
|
||||
const int token_padding_size,
|
||||
const int max_tokens,
|
||||
const bool is_bflot16) {
|
||||
|
||||
const int batch_size = weight.dims()[0];
|
||||
const int M = weight.dims()[1];
|
||||
const int K = weight.dims()[2] * 2;
|
||||
|
||||
if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
|
||||
PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
|
||||
}
|
||||
|
||||
if (token_padding_size == 0) {
|
||||
const int all_tokens = input.dims()[0];
|
||||
if (is_bflot16) {
|
||||
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
|
||||
phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
|
||||
DisPatchW4AFp8Gemm(
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||
tokens.data<int>(),
|
||||
input_row_sum.data<float>(),
|
||||
weight_scale.data<float>(),
|
||||
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
||||
token_padding_size,
|
||||
max_tokens,
|
||||
batch_size,
|
||||
M,
|
||||
K,
|
||||
input.stream());
|
||||
return {out};
|
||||
} else {
|
||||
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
|
||||
phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
|
||||
DisPatchW4AFp8Gemm(
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||
tokens.data<int>(),
|
||||
input_row_sum.data<float>(),
|
||||
weight_scale.data<float>(),
|
||||
reinterpret_cast<cutlass::half_t*>(out_data),
|
||||
token_padding_size,
|
||||
max_tokens,
|
||||
batch_size,
|
||||
M,
|
||||
K,
|
||||
input.stream());
|
||||
return {out};
|
||||
}
|
||||
} else {
|
||||
if (is_bflot16) {
|
||||
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
|
||||
phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
|
||||
DisPatchW4AFp8Gemm(
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||
tokens.data<int>(),
|
||||
input_row_sum.data<float>(),
|
||||
weight_scale.data<float>(),
|
||||
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
||||
token_padding_size,
|
||||
max_tokens,
|
||||
batch_size,
|
||||
M,
|
||||
K,
|
||||
input.stream());
|
||||
return {out};
|
||||
} else {
|
||||
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
|
||||
phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
|
||||
|
||||
DisPatchW4AFp8Gemm(
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||
tokens.data<int>(),
|
||||
input_row_sum.data<float>(),
|
||||
weight_scale.data<float>(),
|
||||
reinterpret_cast<cutlass::half_t*>(out_data),
|
||||
token_padding_size,
|
||||
max_tokens,
|
||||
batch_size,
|
||||
M,
|
||||
K,
|
||||
input.stream());
|
||||
return {out};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
|
||||
const int batch_size = weight.dims()[0];
|
||||
const int M = weight.dims()[1];
|
||||
const int K = weight.dims()[2];
|
||||
paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place());
|
||||
weight_convert(weight.data<uint8_t>(), weight_new.data<uint8_t>(), batch_size, M, K);
|
||||
return {weight_new};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(w4afp8_gemm)
|
||||
.Inputs({"input",
|
||||
"weight",
|
||||
"tokens",
|
||||
"input_row_sum",
|
||||
"weight_scale"})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"token_padding_size: int",
|
||||
"max_tokens: int",
|
||||
"is_bflot16: bool"})
|
||||
.SetKernelFn(PD_KERNEL(W4AFp8Gemm));
|
||||
|
||||
PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert)
|
||||
.Inputs({"weight"})
|
||||
.Outputs({"converted_weight"})
|
||||
.SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));
|
252
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp
Normal file
252
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp
Normal file
@@ -0,0 +1,252 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
|
||||
#include "kernel_traits.h"
|
||||
#include "mainloop_fwd.h"
|
||||
|
||||
template <typename Ktraits>
|
||||
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_geem_kernel(
|
||||
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
|
||||
|
||||
using Element = typename Ktraits::Element;
|
||||
static_assert(cutlass::sizeof_bits_v<Element> == 8);
|
||||
|
||||
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
||||
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
|
||||
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
||||
|
||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
|
||||
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||
static constexpr int M = Ktraits::M;
|
||||
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
|
||||
static constexpr int TAIL_N = Ktraits::TAIL_N;
|
||||
|
||||
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;
|
||||
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
using ElementOutput = typename Ktraits::ElementOutput;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||
|
||||
const int bidm = blockIdx.x;
|
||||
const int bidn = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
if (tidx == 0) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
||||
}
|
||||
|
||||
// Obtain warp index
|
||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
pipeline_params.role = warp_group_idx == 0
|
||||
? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = NumMmaThreads;
|
||||
|
||||
MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{});
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
cute::cluster_arrive_relaxed();
|
||||
cute::cluster_wait();
|
||||
} else {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
const int pre_fix_tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] : 0;
|
||||
|
||||
const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb + 1] - pre_fix_tokens : mainloop_params.tokens[bidb];
|
||||
|
||||
|
||||
if (bidn * kBlockN >= tokens) {
|
||||
return;
|
||||
}
|
||||
|
||||
float* input_row_sum = reinterpret_cast<float*>(
|
||||
shared_memory + sizeof(typename Ktraits::SharedStorage));
|
||||
|
||||
if (warp_group_idx == 0) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
|
||||
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
collective_mainloop.load(
|
||||
mainloop_params,
|
||||
pipeline,
|
||||
smem_pipe_write,
|
||||
shared_storage,
|
||||
tokens,
|
||||
pre_fix_tokens,
|
||||
bidm,
|
||||
bidn,
|
||||
bidb,
|
||||
tidx);
|
||||
} else {
|
||||
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
|
||||
PipelineState smem_pipe_read;
|
||||
|
||||
typename Ktraits::TiledMma tiled_mma;
|
||||
|
||||
typename Ktraits::TiledMma_TAIL tiled_mma_tail;
|
||||
|
||||
const int mma_tidx = tidx - NumCopyThreads;
|
||||
const int lane_id = mma_tidx % 4 * 2;
|
||||
|
||||
const float2 weight_scale = reinterpret_cast<const float2*>(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];
|
||||
|
||||
if constexpr (TokenPackSize == 0) {
|
||||
const int input_sum_idx = pre_fix_tokens + bidn * kBlockN;
|
||||
if (mma_tidx < kBlockN) {
|
||||
reinterpret_cast<float*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
|
||||
}
|
||||
} else {
|
||||
const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN;
|
||||
if (mma_tidx < kBlockN / 4) {
|
||||
reinterpret_cast<float4*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float4*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
|
||||
}
|
||||
}
|
||||
|
||||
const int reamin_tokens = tokens - bidn * kBlockN;
|
||||
|
||||
if (TAIL_N > 0 && reamin_tokens < kBlockN) {
|
||||
Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{}));
|
||||
collective_mainloop.mma<TAIL_N>(
|
||||
mainloop_params,
|
||||
tiled_mma_tail,
|
||||
pipeline,
|
||||
smem_pipe_read,
|
||||
shared_storage,
|
||||
tSrS_tail,
|
||||
mma_tidx);
|
||||
collective_mainloop.store<TAIL_N>(
|
||||
mainloop_params,
|
||||
tSrS_tail,
|
||||
shared_storage,
|
||||
tiled_mma_tail,
|
||||
input_row_sum + lane_id,
|
||||
reinterpret_cast<const float*>(&weight_scale),
|
||||
tokens,
|
||||
pre_fix_tokens,
|
||||
bidm,
|
||||
bidn,
|
||||
bidb,
|
||||
mma_tidx);
|
||||
} else {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
|
||||
collective_mainloop.mma<kBlockN>(
|
||||
mainloop_params,
|
||||
tiled_mma,
|
||||
pipeline,
|
||||
smem_pipe_read,
|
||||
shared_storage,
|
||||
tSrS,
|
||||
mma_tidx);
|
||||
collective_mainloop.store<kBlockN>(
|
||||
mainloop_params,
|
||||
tSrS,
|
||||
shared_storage,
|
||||
tiled_mma,
|
||||
input_row_sum + lane_id,
|
||||
reinterpret_cast<const float*>(&weight_scale),
|
||||
tokens,
|
||||
pre_fix_tokens,
|
||||
bidm,
|
||||
bidn,
|
||||
bidb,
|
||||
mma_tidx);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <int Batch>
|
||||
auto get_gmem_layout(const int Rows, const int Cols) {
|
||||
return make_layout(
|
||||
make_shape(
|
||||
static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Batch)),
|
||||
make_stride(
|
||||
static_cast<int64_t>(Cols),
|
||||
cute::_1{},
|
||||
static_cast<int64_t>(Rows * Cols)));
|
||||
}
|
||||
|
||||
|
||||
template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
|
||||
void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
|
||||
const float *input_row_sum, const int * tokens, const int max_tokens, cudaStream_t stream) {
|
||||
|
||||
using ElementOutput = typename Kernel_traits::ElementOutput;
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
|
||||
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
|
||||
|
||||
constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
|
||||
typename CollectiveMainloop::Params mainloop_params =
|
||||
CollectiveMainloop::to_underlying_arguments({
|
||||
static_cast<Element const*>(A),
|
||||
get_gmem_layout<Batch>(M, K / 2),
|
||||
static_cast<Element const*>(B),
|
||||
get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens * Batch : TokenPackSize, K),
|
||||
static_cast<ElementOutput*>(C),
|
||||
get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
|
||||
weight_scale,
|
||||
input_row_sum,
|
||||
tokens
|
||||
});
|
||||
|
||||
void *kernel;
|
||||
kernel = (void *)w4afp8_geem_kernel<Kernel_traits>;
|
||||
|
||||
int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = M_nums;
|
||||
grid_dims.y = N_nums;
|
||||
grid_dims.z = Batch;
|
||||
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
|
||||
dim3 block_dims(ctaSize);
|
||||
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
|
||||
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
|
||||
cutlass::launch_kernel_on_cluster(
|
||||
launch_params, kernel, mainloop_params);
|
||||
}
|
@@ -494,6 +494,8 @@ elif paddle.is_compiled_with_cuda():
|
||||
if cc >= 90 and nvcc_version >= 12.0:
|
||||
# Hopper optmized mla
|
||||
sources += find_end_files("gpu_ops/mla_attn", ".cu")
|
||||
os.system("python utils/auto_gen_w4afp8_gemm_kernel.py")
|
||||
sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
|
||||
|
||||
setup(
|
||||
name="fastdeploy_ops",
|
||||
|
207
custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
Normal file
207
custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
file_dir = "./gpu_ops/w4afp8_gemm/"
|
||||
|
||||
gemm_template_head = """
|
||||
#pragma once
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <cuda_fp16.h>
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
"""
|
||||
gemm_template_case = """
|
||||
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
||||
const cutlass::float_e4m3_t * weight,
|
||||
const cutlass::float_e4m3_t * input,
|
||||
{cutlass_type} * out,
|
||||
const float *weight_scale,
|
||||
const float *input_row_sum,
|
||||
const int *tokens,
|
||||
const int max_tokens,
|
||||
cudaStream_t stream);
|
||||
"""
|
||||
|
||||
gemm_template_cu_head = """
|
||||
#include "paddle/extension.h"
|
||||
#include "w4afp8_gemm_template.h"
|
||||
#include "w4afp8_gemm_kernel.hpp"
|
||||
|
||||
"""
|
||||
gemm_template_cu_template = """
|
||||
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
||||
const cutlass::float_e4m3_t * weight,
|
||||
const cutlass::float_e4m3_t * input,
|
||||
{cutlass_type} * out,
|
||||
const float *weight_scale,
|
||||
const float *input_row_sum,
|
||||
const int *tokens,
|
||||
const int max_tokens,
|
||||
cudaStream_t stream) {{
|
||||
|
||||
constexpr static int M = {M};
|
||||
constexpr static int K = {K};
|
||||
constexpr static int Batch = {BATCH};
|
||||
constexpr static int TokenPackSize = {PADDING};
|
||||
constexpr static int kBlockN = {N};
|
||||
constexpr static int kBlockN_TAIL = {TAILN};
|
||||
constexpr static int kBlockM = 128;
|
||||
constexpr static int kBlockK = 128;
|
||||
constexpr static int kNWarps = 4 + kBlockM / 16;
|
||||
constexpr static int kStages = 5;
|
||||
constexpr int kCluster = 1;
|
||||
static_assert(K % kBlockK == 0);
|
||||
constexpr int kTiles = K / kBlockK;
|
||||
|
||||
using Kernel_traits = Kernel_traits<
|
||||
kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
|
||||
M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
|
||||
{cutlass_type}>;
|
||||
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
|
||||
Kernel_traits, M, K, Batch, TokenPackSize>
|
||||
(weight, input, out, weight_scale,
|
||||
input_row_sum, tokens, max_tokens, stream);
|
||||
}}
|
||||
"""
|
||||
|
||||
gemm_case = [
|
||||
[8192, 3584, 8, 0], # eb45T ffn1
|
||||
[8192, 3584, 8, 2048], # eb45T ffn1
|
||||
[7168, 8192, 8, 0], # eb45T ffn2
|
||||
[7168, 8192, 8, 2048], # eb45T ffn2
|
||||
]
|
||||
|
||||
dtype = ["BF16", "FP16"]
|
||||
|
||||
|
||||
def get_cutlass_type(type):
|
||||
if type == "BF16":
|
||||
return "cutlass::bfloat16_t"
|
||||
elif type == "FP16":
|
||||
return "cutlass::half_t"
|
||||
|
||||
|
||||
template_head_file = open(f"{file_dir}w4afp8_gemm_template.h", "w")
|
||||
template_head_file.write(gemm_template_head)
|
||||
|
||||
for type in dtype:
|
||||
for case in gemm_case:
|
||||
for n in range(16, 257, 16):
|
||||
template_head_file.write(
|
||||
gemm_template_case.format(
|
||||
M=case[0],
|
||||
K=case[1],
|
||||
N=n,
|
||||
BATCH=case[2],
|
||||
TYPE=type,
|
||||
PADDING=case[3],
|
||||
TAILN=0,
|
||||
cutlass_type=get_cutlass_type(type),
|
||||
)
|
||||
)
|
||||
template_head_file.write(
|
||||
gemm_template_case.format(
|
||||
M=case[0],
|
||||
K=case[1],
|
||||
N=256,
|
||||
BATCH=case[2],
|
||||
TYPE=type,
|
||||
PADDING=case[3],
|
||||
TAILN=n - 16,
|
||||
cutlass_type=get_cutlass_type(type),
|
||||
)
|
||||
)
|
||||
|
||||
template_cu_file = open(
|
||||
f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
|
||||
)
|
||||
template_cu_file.write(gemm_template_cu_head)
|
||||
template_cu_file.write(
|
||||
gemm_template_cu_template.format(
|
||||
M=case[0],
|
||||
K=case[1],
|
||||
N=n,
|
||||
BATCH=case[2],
|
||||
TYPE=type,
|
||||
PADDING=case[3],
|
||||
TAILN=0,
|
||||
cutlass_type=get_cutlass_type(type),
|
||||
)
|
||||
)
|
||||
|
||||
template_cu_file.close()
|
||||
|
||||
template_cu_file = open(
|
||||
f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
|
||||
)
|
||||
template_cu_file.write(gemm_template_cu_head)
|
||||
template_cu_file.write(
|
||||
gemm_template_cu_template.format(
|
||||
M=case[0],
|
||||
K=case[1],
|
||||
N=256,
|
||||
BATCH=case[2],
|
||||
TYPE=type,
|
||||
PADDING=case[3],
|
||||
TAILN=n - 16,
|
||||
cutlass_type=get_cutlass_type(type),
|
||||
)
|
||||
)
|
||||
|
||||
template_cu_file.close()
|
||||
|
||||
for type in dtype:
|
||||
template_head_file.write("\n")
|
||||
template_head_file.write(
|
||||
"""#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
|
||||
if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
|
||||
TYPE=type
|
||||
)
|
||||
)
|
||||
|
||||
template_head_file.write("\n")
|
||||
|
||||
for case in gemm_case:
|
||||
for n in range(16, 257, 16):
|
||||
template_head_file.write(
|
||||
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
|
||||
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
|
||||
M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
|
||||
)
|
||||
)
|
||||
template_head_file.write("\n")
|
||||
template_head_file.write(
|
||||
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
|
||||
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
|
||||
M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16
|
||||
)
|
||||
)
|
||||
template_head_file.write("\n")
|
||||
|
||||
template_head_file.write(
|
||||
""" } else { \\
|
||||
PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
|
||||
} \\
|
||||
}"""
|
||||
)
|
||||
|
||||
template_head_file.close()
|
103
test/operators/test_w4afp8_gemm.py
Normal file
103
test/operators/test_w4afp8_gemm.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
|
||||
|
||||
|
||||
def w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N):
|
||||
all_tokens = int(tokens.sum())
|
||||
out = paddle.zeros([all_tokens, N], dtype="bfloat16")
|
||||
pre_fix_token = 0
|
||||
for i in range(BATCH):
|
||||
input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
|
||||
weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
|
||||
out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
|
||||
out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
|
||||
pre_fix_token += tokens[i]
|
||||
return out
|
||||
|
||||
|
||||
def peruate_scale(weight_scale):
|
||||
weight_scale = weight_scale.reshape([BATCH, N])
|
||||
temp = paddle.zeros([16])
|
||||
for b in range(BATCH):
|
||||
for n in range(0, N, 16):
|
||||
temp[:] = weight_scale[b, n : n + 16]
|
||||
for j in range(0, 16, 2):
|
||||
weight_scale[b, n + j] = temp[j // 2]
|
||||
weight_scale[b, n + j + 1] = temp[j // 2 + 8]
|
||||
return weight_scale
|
||||
|
||||
|
||||
paddle.seed(0)
|
||||
tokens_per_group = 32
|
||||
N = 8192
|
||||
K = 3584
|
||||
BATCH = 8
|
||||
TokenPadding = 0
|
||||
|
||||
tokens = [tokens_per_group] * BATCH
|
||||
tokens_perfix_sum = np.cumsum(tokens)
|
||||
tokens_perfix_sum = np.insert(tokens_perfix_sum, 0, 0)
|
||||
|
||||
tokens = paddle.to_tensor(tokens, dtype="int32")
|
||||
tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int32")
|
||||
|
||||
all_tokens = int(tokens.sum())
|
||||
|
||||
input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
|
||||
input_bf16 = input_fp8.astype("bfloat16")
|
||||
weight = paddle.randn([BATCH, N, K], dtype="bfloat16") / 10
|
||||
|
||||
weight_scale = 7 / weight.abs().max(axis=-1).reshape([BATCH, N, 1])
|
||||
weight_quant = (weight * weight_scale).astype("int") + 7
|
||||
weight_quant = paddle.clip(weight_quant, 0, 14)
|
||||
weight_quant = weight_quant.astype("bfloat16")
|
||||
weight_dequant_scale = 1 / weight_scale.astype("float32")
|
||||
input_row_sum = input_bf16.sum(axis=1) * -7 / 512
|
||||
max_tokens = int(tokens.max())
|
||||
|
||||
out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
|
||||
weight_dequant_scale = paddle.to_tensor(peruate_scale(weight_dequant_scale) * 512)
|
||||
|
||||
weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())
|
||||
|
||||
if TokenPadding == 0:
|
||||
out_cuda = w4afp8_gemm(
|
||||
input_fp8,
|
||||
weight_int4.cuda(),
|
||||
tokens_perfix_sum,
|
||||
input_row_sum.astype("float32"),
|
||||
weight_dequant_scale.astype("float32"),
|
||||
int(TokenPadding),
|
||||
max_tokens,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
out_cuda = w4afp8_gemm(
|
||||
input_fp8,
|
||||
weight_int4.cuda(),
|
||||
tokens,
|
||||
input_row_sum.astype("float32"),
|
||||
weight_dequant_scale.astype("float32"),
|
||||
int(TokenPadding),
|
||||
max_tokens,
|
||||
True,
|
||||
)
|
||||
|
||||
gap = (out_cuda - out_naive).abs()
|
||||
assert float(gap.mean()) < 0.07
|
Reference in New Issue
Block a user