fix w4afp8 (#5634)

This commit is contained in:
lizexu123
2025-12-22 13:39:41 +08:00
committed by GitHub
parent 6eada4929d
commit 6d323769dd
3 changed files with 59 additions and 47 deletions

View File

@@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
template <int Experts> template <int Experts>
auto get_gmem_layout(const int Rows, const int Cols) { auto get_gmem_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Rows), return make_layout(
make_shape(static_cast<int64_t>(Rows),
static_cast<int64_t>(Cols), static_cast<int64_t>(Cols),
static_cast<int64_t>(Experts)), static_cast<int64_t>(Experts)),
make_stride(static_cast<int64_t>(Cols), make_stride(static_cast<int64_t>(Cols),
cute::_1{}, cute::_1{},
static_cast<int64_t>(Rows * Cols))); static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
} }
template <int Experts> template <int Experts>
auto get_scale_layout(const int Rows, const int Cols) { auto get_scale_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Cols), return make_layout(
make_shape(static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows), static_cast<int64_t>(Rows),
static_cast<int64_t>(Experts)), static_cast<int64_t>(Experts)),
make_stride(cute::_1{}, make_stride(cute::_1{},
static_cast<int64_t>(Cols), static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows * Cols))); static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
} }
template <typename InputType, template <typename InputType,

View File

@@ -85,7 +85,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
""" """
# [M, K, Number of experts, token Padding Size, weight K group size] # [M, K, Number of experts, token Padding Size, weight K group size]
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]] gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128], [256, 5120, 128, 0, 128]]
dtype = ["BF16"] dtype = ["BF16"]

View File

@@ -27,12 +27,18 @@ from fastdeploy.model_executor.ops.gpu import (
class TestW4AFP8GEMM(unittest.TestCase): class TestW4AFP8GEMM(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.seed(0) paddle.seed(0)
self.tokens_per_group = 1 self.test_cases = [
self.N = 256 [1, 2, 256, 256],
self.K = 256 [4096, 128, 256, 5120],
self.BATCH = 2 ]
self.TokenPadding = 0 self.TokenPadding = 0
def set_data(self, tokens_per_group, Experts, N, K):
self.tokens_per_group = tokens_per_group
self.N = N
self.K = K
self.BATCH = Experts
tokens = [self.tokens_per_group] * self.BATCH tokens = [self.tokens_per_group] * self.BATCH
self.tokens_prefix_sum = np.cumsum(tokens) self.tokens_prefix_sum = np.cumsum(tokens)
@@ -90,6 +96,10 @@ class TestW4AFP8GEMM(unittest.TestCase):
return processed_weight_scale return processed_weight_scale
def test_w4afp8_gemm(self): def test_w4afp8_gemm(self):
for test_case in self.test_cases:
tokens_per_group, Experts, N, K = test_case
self.set_data(tokens_per_group, Experts, N, K)
out_naive = self.w4afp8_gemm_naive( out_naive = self.w4afp8_gemm_naive(
self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale
) )