fix w4afp8 (#5634)

This commit is contained in:
lizexu123
2025-12-22 13:39:41 +08:00
committed by GitHub
parent 6eada4929d
commit 6d323769dd
3 changed files with 59 additions and 47 deletions

View File

@@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
template <int Experts> template <int Experts>
auto get_gmem_layout(const int Rows, const int Cols) { auto get_gmem_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Rows), return make_layout(
static_cast<int64_t>(Cols), make_shape(static_cast<int64_t>(Rows),
static_cast<int64_t>(Experts)), static_cast<int64_t>(Cols),
make_stride(static_cast<int64_t>(Cols), static_cast<int64_t>(Experts)),
cute::_1{}, make_stride(static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows * Cols))); cute::_1{},
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
} }
template <int Experts> template <int Experts>
auto get_scale_layout(const int Rows, const int Cols) { auto get_scale_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Cols), return make_layout(
static_cast<int64_t>(Rows), make_shape(static_cast<int64_t>(Cols),
static_cast<int64_t>(Experts)), static_cast<int64_t>(Rows),
make_stride(cute::_1{}, static_cast<int64_t>(Experts)),
static_cast<int64_t>(Cols), make_stride(cute::_1{},
static_cast<int64_t>(Rows * Cols))); static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
} }
template <typename InputType, template <typename InputType,

View File

@@ -85,7 +85,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
""" """
# [M, K, Number of experts, token Padding Size, weight K group size] # [M, K, Number of experts, token Padding Size, weight K group size]
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]] gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128], [256, 5120, 128, 0, 128]]
dtype = ["BF16"] dtype = ["BF16"]

View File

@@ -27,12 +27,18 @@ from fastdeploy.model_executor.ops.gpu import (
class TestW4AFP8GEMM(unittest.TestCase): class TestW4AFP8GEMM(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.seed(0) paddle.seed(0)
self.tokens_per_group = 1 self.test_cases = [
self.N = 256 [1, 2, 256, 256],
self.K = 256 [4096, 128, 256, 5120],
self.BATCH = 2 ]
self.TokenPadding = 0 self.TokenPadding = 0
def set_data(self, tokens_per_group, Experts, N, K):
self.tokens_per_group = tokens_per_group
self.N = N
self.K = K
self.BATCH = Experts
tokens = [self.tokens_per_group] * self.BATCH tokens = [self.tokens_per_group] * self.BATCH
self.tokens_prefix_sum = np.cumsum(tokens) self.tokens_prefix_sum = np.cumsum(tokens)
@@ -90,39 +96,43 @@ class TestW4AFP8GEMM(unittest.TestCase):
return processed_weight_scale return processed_weight_scale
def test_w4afp8_gemm(self): def test_w4afp8_gemm(self):
out_naive = self.w4afp8_gemm_naive( for test_case in self.test_cases:
self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale tokens_per_group, Experts, N, K = test_case
)
# weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512) self.set_data(tokens_per_group, Experts, N, K)
weight_dequant_scale = self.get_per_group_scale(self.weight_dequant_scale * 512) out_naive = self.w4afp8_gemm_naive(
weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()).cuda() self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale
if self.TokenPadding == 0:
out_cuda = w4afp8_gemm(
self.input_fp8,
weight_int4,
self.tokens_prefix_sum,
weight_dequant_scale.astype("float32"),
None,
int(self.TokenPadding),
self.all_tokens,
True,
)
else:
out_cuda = w4afp8_gemm(
self.input_fp8,
weight_int4,
self.tokens,
weight_dequant_scale.astype("float32"),
None,
int(self.TokenPadding),
self.max_tokens,
True,
) )
gap = (out_cuda - out_naive).abs() # weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
self.assertLess(float(gap.mean()), 0.11) weight_dequant_scale = self.get_per_group_scale(self.weight_dequant_scale * 512)
weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()).cuda()
if self.TokenPadding == 0:
out_cuda = w4afp8_gemm(
self.input_fp8,
weight_int4,
self.tokens_prefix_sum,
weight_dequant_scale.astype("float32"),
None,
int(self.TokenPadding),
self.all_tokens,
True,
)
else:
out_cuda = w4afp8_gemm(
self.input_fp8,
weight_int4,
self.tokens,
weight_dequant_scale.astype("float32"),
None,
int(self.TokenPadding),
self.max_tokens,
True,
)
gap = (out_cuda - out_naive).abs()
self.assertLess(float(gap.mean()), 0.11)
if __name__ == "__main__": if __name__ == "__main__":