Revert "【New Feature】W4afp8 supports per group quantization (#4272)" (#4854)

This reverts commit 93fcf7e4ec.
This commit is contained in:
YuBaoku
2025-11-06 17:48:28 +08:00
committed by GitHub
parent 3478d20262
commit 819b2dbbae
26 changed files with 1718 additions and 4378 deletions

View File

@@ -24,116 +24,91 @@
#include <cuda_bf16.h>
#endif
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
using namespace cute;
template <typename T>
template<typename T>
struct PackedHalf;
template <>
template<>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
using Type = __half2;
};
template <>
template<>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
using Type = nv_bfloat162;
};
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(
Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag =
convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(
tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <int numel>
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t *src,
int32_t *dst1,
int32_t *dst2) {
#pragma unroll
for (int i = 0; i < numel; ++i) {
uint32_t head1 = src[i] & 0x80808080;
dst1[i] = (src[i] >> 4) & 0x07070707;
dst1[i] = dst1[i] | head1;
uint32_t head2 = (src[i] & 0x08080808) << 4;
dst2[i] = src[i] & 0x07070707;
dst2[i] = dst2[i] | head2;
}
}
template <int wg_wait = 0,
bool arrive = true,
bool commit = true,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3,
typename TiledMma,
typename ThrCopyA,
typename TiledCopyA>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma,
Tensor0 &tCrA,
Tensor1 &tCsA,
Tensor2 const &tCrB,
Tensor3 &tCrC,
TiledCopyA const &tiled_copy_A,
ThrCopyA const &thr_copy_A) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value;
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
}
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
if (k_block < size<2>(tCrA) - 1) {
cute::copy(tiled_copy_A,
tCsA(_, _, k_block + 1),
tCrA_copy_view(_, _, k_block + 1));
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) {
#pragma unroll
for (int i = 0; i < numel; ++i) {
dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;
dst2[i] = src[i] & 0x0f0f0f0f;
}
int32_t *tCrA_data =
reinterpret_cast<int32_t *>(tCrA(_, _, k_block).data());
int32_t *tCrA1_data =
reinterpret_cast<int32_t *>(tCrA1(_, _, k_block).data());
int32_t *tCrA2_data =
reinterpret_cast<int32_t *>(tCrA2(_, _, k_block).data());
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
cute::gemm(tiled_mma, tCrA1(_, _, k_block), tCrB(_, _, 2 * k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
cute::gemm(
tiled_mma, tCrA2(_, _, k_block), tCrB(_, _, 2 * k_block + 1), tCrC);
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
}
}
template <int wg_wait=0, bool arrive=true,
bool commit=true, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename TiledMma,
typename ThrCopyA, typename TiledCopyA>
__forceinline__ __device__ void gemm(
TiledMma &tiled_mma,
Tensor0 &tCrA,
Tensor1 &tCsA,
Tensor2 const &tCrB,
Tensor3 &tCrC,
TiledCopyA const &tiled_copy_A,
ThrCopyA const &thr_copy_A) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
if (k_block < size<2>(tCrA) - 1) {
cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
}
int32_t * tCrA_data = reinterpret_cast<int32_t *>(tCrA(_,_,k_block).data());
int32_t * tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_,_,k_block).data());
int32_t * tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_,_,k_block).data());
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC);
cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC);
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}