fix w4afp8 (#5634)

This commit is contained in:
lizexu123
2025-12-22 13:39:41 +08:00
committed by GitHub
parent 6eada4929d
commit 6d323769dd
3 changed files with 59 additions and 47 deletions

View File

@@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
template <int Experts>
auto get_gmem_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Rows),
static_cast<int64_t>(Cols),
static_cast<int64_t>(Experts)),
make_stride(static_cast<int64_t>(Cols),
cute::_1{},
static_cast<int64_t>(Rows * Cols)));
return make_layout(
make_shape(static_cast<int64_t>(Rows),
static_cast<int64_t>(Cols),
static_cast<int64_t>(Experts)),
make_stride(static_cast<int64_t>(Cols),
cute::_1{},
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
}
template <int Experts>
auto get_scale_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows),
static_cast<int64_t>(Experts)),
make_stride(cute::_1{},
static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows * Cols)));
return make_layout(
make_shape(static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows),
static_cast<int64_t>(Experts)),
make_stride(cute::_1{},
static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
}
template <typename InputType,

View File

@@ -85,7 +85,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
"""
# [M, K, Number of experts, token Padding Size, weight K group size]
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]]
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128], [256, 5120, 128, 0, 128]]
dtype = ["BF16"]