fix w4afp8 (#5634)

This commit is contained in:
lizexu123
2025-12-22 13:39:41 +08:00
committed by GitHub
parent 6eada4929d
commit 6d323769dd
3 changed files with 59 additions and 47 deletions

View File

@@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
// Builds the global-memory layout for a batch of per-expert row-major
// matrices: shape (Rows, Cols, Experts) with strides (Cols, 1, Rows*Cols).
//
// The expert stride is formed as int64 * int64 so the multiplication itself
// is performed in 64-bit; the previous form multiplied two 32-bit ints and
// only widened the (possibly already-overflowed) product afterwards.
template <int Experts>
auto get_gmem_layout(const int Rows, const int Cols) {
  return make_layout(
      make_shape(static_cast<int64_t>(Rows),
                 static_cast<int64_t>(Cols),
                 static_cast<int64_t>(Experts)),
      make_stride(static_cast<int64_t>(Cols),
                  cute::_1{},  // contiguous along columns
                  static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
}
// Builds the layout for per-expert scale tensors: shape (Cols, Rows, Experts)
// with strides (1, Cols, Rows*Cols) — i.e. contiguous along the first mode,
// mirroring get_gmem_layout with the two leading modes transposed.
//
// As in get_gmem_layout, the expert stride multiplies two int64 values so the
// product cannot overflow 32-bit arithmetic before being widened.
template <int Experts>
auto get_scale_layout(const int Rows, const int Cols) {
  return make_layout(
      make_shape(static_cast<int64_t>(Cols),
                 static_cast<int64_t>(Rows),
                 static_cast<int64_t>(Experts)),
      make_stride(cute::_1{},  // contiguous along the first mode
                  static_cast<int64_t>(Cols),
                  static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
}
template <typename InputType,

View File

@@ -85,7 +85,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
"""
# [M, K, Number of experts, token Padding Size, weight K group size]
# The third case (K=5120, 128 experts) was added to cover a large-K,
# many-expert shape alongside the original small square cases.
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128], [256, 5120, 128, 0, 128]]
dtype = ["BF16"]

View File

@@ -27,12 +27,18 @@ from fastdeploy.model_executor.ops.gpu import (
class TestW4AFP8GEMM(unittest.TestCase):
def setUp(self):
    """Seed the RNG and declare the GEMM shapes exercised by the test.

    Each entry of ``test_cases`` is ``[tokens_per_group, Experts, N, K]``;
    ``set_data`` materializes the tensors for one entry. The fixed
    per-shape attributes of the old single-case setup were moved into
    ``set_data`` so multiple shapes can be run in one test.
    """
    paddle.seed(0)  # deterministic inputs across runs
    self.test_cases = [
        [1, 2, 256, 256],
        [4096, 128, 256, 5120],
    ]
    # 0 selects the prefix-sum token path in w4afp8_gemm; a nonzero value
    # would select the padded-token path (see test_w4afp8_gemm).
    self.TokenPadding = 0
# Configure one test case's shapes and per-expert token bookkeeping.
# NOTE(review): this method is truncated here by the elided diff region —
# the remainder of its body (tensor construction) is not visible.
def set_data(self, tokens_per_group, Experts, N, K):
self.tokens_per_group = tokens_per_group
self.N = N
self.K = K
# BATCH doubles as the expert count for the grouped GEMM.
self.BATCH = Experts
# Every expert receives the same number of tokens in this test.
tokens = [self.tokens_per_group] * self.BATCH
# Cumulative token offsets consumed by the TokenPadding == 0 path.
self.tokens_prefix_sum = np.cumsum(tokens)
@@ -90,39 +96,43 @@ class TestW4AFP8GEMM(unittest.TestCase):
return processed_weight_scale
def test_w4afp8_gemm(self):
    """Compare the fused w4afp8 GEMM kernel against a naive reference.

    Iterates over every ``[tokens_per_group, Experts, N, K]`` shape in
    ``self.test_cases``; for each, the mean absolute difference between
    the CUDA kernel output and the bf16 reference must stay below 0.11.
    """
    for test_case in self.test_cases:
        tokens_per_group, Experts, N, K = test_case
        self.set_data(tokens_per_group, Experts, N, K)
        # Reference result computed in bf16 with the unconverted weights.
        out_naive = self.w4afp8_gemm_naive(
            self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale
        )
        # weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
        weight_dequant_scale = self.get_per_group_scale(self.weight_dequant_scale * 512)
        # Pack int4 weights into the kernel's expected interleaved format.
        weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()).cuda()
        if self.TokenPadding == 0:
            # Dense path: per-expert prefix sums plus the total token count.
            out_cuda = w4afp8_gemm(
                self.input_fp8,
                weight_int4,
                self.tokens_prefix_sum,
                weight_dequant_scale.astype("float32"),
                None,
                int(self.TokenPadding),
                self.all_tokens,
                True,
            )
        else:
            # Padded path: raw per-expert token counts plus the padded max.
            out_cuda = w4afp8_gemm(
                self.input_fp8,
                weight_int4,
                self.tokens,
                weight_dequant_scale.astype("float32"),
                None,
                int(self.TokenPadding),
                self.max_tokens,
                True,
            )
        gap = (out_cuda - out_naive).abs()
        self.assertLess(float(gap.mean()), 0.11)
if __name__ == "__main__":